Skip to content

Commit c48c3f3

Browse files
committed
Pipeline orchestration with async support
1 parent 86554af commit c48c3f3

File tree

9 files changed: 309 additions (+309) and 164 deletions (-164).

poetry.lock

Lines changed: 27 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ weaviate-client = {version = "^4.6.1", optional = true}
3434
pinecone-client = {version = "^4.1.0", optional = true}
3535
types-mock = "^5.1.0.20240425"
3636
eval-type-backport = "^0.2.0"
37+
jsonpath-ng = "^1.6.1"
3738

3839
[tool.poetry.group.dev.dependencies]
3940
pylint = "^3.1.0"
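(File path missing from this capture: the following diff belongs to a separate Python example script that defines DocumentChunker, SchemaBuilder, ERExtractor and Writer, not to pyproject.toml.)

Lines changed: 31 additions & 44 deletions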
Original file line numberDiff line numberDiff line change
@@ -1,87 +1,74 @@
11
import asyncio
2+
from typing import Any
23
from neo4j_genai.core.pipeline import Component
34

45

56
class DocumentChunker(Component):
6-
def process(self, text: str):
7-
return {
8-
"chunks": [t.strip() for t in text.split(".") if t.strip()]
9-
}
7+
async def process(self, text: str) -> dict[str, Any]:
8+
return {"chunks": [t.strip() for t in text.split(".") if t.strip()]}
109

1110

1211
class SchemaBuilder(Component):
13-
def process(self, schema: dict):
12+
async def process(self, schema: dict[str, Any]) -> dict[str, Any]:
1413
return {"schema": schema}
1514

1615

1716
class ERExtractor(Component):
18-
def process(self, chunks: list[str], schema: str) -> dict:
17+
async def _process_chunk(self, chunk: str, schema: str) -> dict[str, Any]:
1918
return {
2019
"data": {
2120
"entities": [{"label": "Person", "properties": {"name": "John Doe"}}],
22-
"relations": []
21+
"relations": [],
2322
}
2423
}
2524

26-
async def _process_chunk(self, chunk: str, schema: str):
27-
return {
28-
"data": {
29-
"entities": [{"label": "Person", "properties": {"name": "John Doe"}}],
30-
"relations": []
31-
}
32-
}
33-
34-
async def aprocess(self, chunks: list[str], schema: str) -> dict:
35-
tasks = [
36-
self._process_chunk(chunk, schema)
37-
for chunk in chunks
38-
]
25+
async def process(self, chunks: list[str], schema: str) -> dict[str, Any]:
26+
tasks = [self._process_chunk(chunk, schema) for chunk in chunks]
3927
result = await asyncio.gather(*tasks)
40-
merged_result = {"data": {"entities": [], "relations": []}}
28+
merged_result: dict[str, Any] = {"data": {"entities": [], "relations": []}}
4129
for res in result:
4230
merged_result["data"]["entities"] += res["data"]["entities"]
4331
merged_result["data"]["relations"] += res["data"]["relations"]
4432
return merged_result
4533

4634

4735
class Writer(Component):
48-
def process(self, entities: dict, relations: dict) -> dict:
36+
async def process(
37+
self, entities: dict[str, Any], relations: dict[str, Any]
38+
) -> dict[str, Any]:
4939
return {
5040
"status": "OK",
5141
"entities": entities,
5242
"relations": relations,
5343
}
5444

5545

56-
if __name__ == '__main__':
46+
if __name__ == "__main__":
5747
from neo4j_genai.core.pipeline import Pipeline
5848

5949
pipe = Pipeline()
6050
pipe.add_component("chunker", DocumentChunker())
61-
pipe.add_component("schema", SchemaBuilder())
51+
pipe.add_component("schema", SchemaBuilder())
6252
pipe.add_component("extractor", ERExtractor())
6353
pipe.add_component("writer", Writer())
64-
pipe.connect("chunker", "extractor", input_defs={
65-
"chunks": "chunker.chunks"
66-
})
67-
pipe.connect("schema", "extractor", input_defs={
68-
"schema": "schema.schema"
69-
})
70-
pipe.connect("extractor", "writer", input_defs={
71-
"entities": "extractor.data.entities",
72-
"relations": "extractor.data.relations",
73-
})
54+
pipe.connect("chunker", "extractor", input_defs={"chunks": "chunker.chunks"})
55+
pipe.connect("schema", "extractor", input_defs={"schema": "schema.schema"})
56+
pipe.connect(
57+
"extractor",
58+
"writer",
59+
input_defs={
60+
"entities": "extractor.data.entities",
61+
"relations": "extractor.data.relations",
62+
},
63+
)
7464

7565
pipe_inputs = {
76-
"chunker":
77-
{
78-
"text": "Graphs are everywhere. "
79-
"GraphRAG is the future of Artificial Intelligence. "
80-
"Robots are already running the world."
81-
},
82-
"schema": {
83-
"schema": "Person OWNS House"
84-
}
66+
"chunker": {
67+
"text": "Graphs are everywhere. "
68+
"GraphRAG is the future of Artificial Intelligence. "
69+
"Robots are already running the world."
70+
},
71+
"schema": {"schema": "Person OWNS House"},
8572
}
8673
# print(pipe.run_all(pipe_inputs))
87-
print(asyncio.run(pipe.arun_all(pipe_inputs)))
74+
print(asyncio.run(pipe.run(pipe_inputs)))

src/neo4j_genai/components/rag.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
from typing import Any
23

34
from neo4j_genai.types import RetrieverResult, RetrieverResultItem
@@ -13,18 +14,18 @@ def search(self, *args: Any, **kwargs: Any) -> RetrieverResult:
1314
]
1415
)
1516

16-
def process(self, query: str):
17+
async def process(self, query: str) -> dict[str, Any]:
1718
res = self.search(query)
1819
return {"context": "\n".join(c.content for c in res.items)}
1920

2021

2122
class PromptTemplate(Component):
22-
def process(self, query: str, context: list):
23+
async def process(self, query: str, context: list[str]) -> dict[str, Any]:
2324
return {"prompt": f"my prompt using '{context}', query '{query}'"}
2425

2526

2627
class LLM(Component):
27-
def process(self, prompt):
28+
async def process(self, prompt: str) -> dict[str, Any]:
2829
return {"answer": f"some text based on '{prompt}'"}
2930

3031

@@ -37,4 +38,8 @@ def process(self, prompt):
3738
pipe.connect("augment", "generate", {"prompt": "augment.prompt"})
3839

3940
query = "my question"
40-
print(pipe.run_all({"retrieve": {"query": query}, "augment": {"query": query}}))
41+
print(
42+
asyncio.run(
43+
pipe.run({"retrieve": {"query": query}, "augment": {"query": query}})
44+
)
45+
)

src/neo4j_genai/core/graph.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,42 @@
11
"""
22
Basic graph structure for Pipeline
33
"""
4-
from typing import Optional
4+
5+
from typing import Optional, Any
6+
57

68
class Node:
7-
def __init__(self, name: str, data: dict) -> None:
9+
def __init__(self, name: str, data: dict[str, Any]) -> None:
810
self.name = name
911
self.data = data
10-
self.parents = []
11-
self.children = []
12+
self.parents: list[Node] = []
13+
self.children: list[Node] = []
1214

13-
def is_root(self):
15+
def is_root(self) -> bool:
1416
return len(self.parents) == 0
1517

16-
def is_leaf(self):
18+
def is_leaf(self) -> bool:
1719
return len(self.children) == 0
1820

1921

2022
class Edge:
21-
def __init__(self, start: Node, end: Node, data: Optional[dict] = None):
23+
def __init__(self, start: Node, end: Node, data: Optional[dict[str, Any]] = None):
2224
self.start = start
2325
self.end = end
2426
self.data = data
2527

2628

2729
class Graph:
28-
def __init__(self):
29-
self._nodes = {}
30-
self._edges = []
30+
def __init__(self) -> None:
31+
self._nodes: dict[str, Node] = {}
32+
self._edges: list[Edge] = []
3133

32-
def add_node(self, node: Node):
34+
def add_node(self, node: Node) -> None:
3335
if node in self:
3436
raise ValueError(f"Node {node.name} already exists")
3537
self._nodes[node.name] = node
3638

37-
def connect(self, start: Node, end: Node, data: dict):
39+
def connect(self, start: Node, end: Node, data: dict[str, Any]) -> None:
3840
self._edges.append(Edge(start, end, data))
3941
self._nodes[end.name].parents.append(start)
4042
self._nodes[start.name].children.append(end)
@@ -43,28 +45,28 @@ def get_node_by_name(self, name: str, raise_exception: bool = False) -> Node:
4345
node = self._nodes.get(name)
4446
if node is None and raise_exception:
4547
raise KeyError(f"Component {name} not in graph")
46-
return node
48+
return node # type: ignore
4749

48-
def roots(self):
50+
def roots(self) -> list[Node]:
4951
root = []
5052
for node in self._nodes.values():
5153
if node.is_root():
5254
root.append(node)
5355
return root
5456

55-
def next_edges(self, node):
57+
def next_edges(self, node: Node) -> list[Edge]:
5658
res = []
5759
for edge in self._edges:
5860
if edge.start == node:
5961
res.append(edge)
6062
return res
6163

62-
def previous_edges(self, node):
64+
def previous_edges(self, node: Node) -> list[Edge]:
6365
res = []
6466
for edge in self._edges:
6567
if edge.end == node:
6668
res.append(edge)
6769
return res
6870

69-
def __contains__(self, node: Node):
71+
def __contains__(self, node: Node) -> bool:
7072
return node.name in self._nodes

0 commit comments

Comments
 (0)