Commit f4c12c8

PoC: run pipeline step by step

1 parent eed1a04

3 files changed: +156, −19 lines
Lines changed: 85 additions & 0 deletions (new file)

@@ -0,0 +1,85 @@
from __future__ import annotations

import asyncio

from neo4j_graphrag.embeddings.openai import OpenAIEmbeddings
from neo4j_graphrag.experimental.components.embedder import TextChunkEmbedder
from neo4j_graphrag.experimental.components.kg_writer import Neo4jWriter
from neo4j_graphrag.experimental.components.lexical_graph import LexicalGraphBuilder
from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import (
    FixedSizeSplitter,
)
from neo4j_graphrag.experimental.components.types import LexicalGraphConfig
from neo4j_graphrag.experimental.pipeline import Pipeline
from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult

import neo4j


async def main(neo4j_driver: neo4j.Driver) -> PipelineResult | None:
    """Define and run the lexical graph builder pipeline, instantiating
    a few components:

    - Text Splitter: to split the text into manageable chunks of fixed size
    - Chunk Embedder: to embed the chunks' text (imported above but not wired
      into this PoC pipeline)
    - Lexical Graph Builder: to build the lexical graph, i.e. to create the
      chunk nodes and the relationships between them
    - KG Writer: to save the lexical graph to Neo4j
    """
    pipe = Pipeline()
    # define the components
    pipe.add_component(
        FixedSizeSplitter(chunk_size=200, chunk_overlap=10, approximate=False),
        "splitter",
    )
    # optional: define some custom node labels for the lexical graph
    lexical_graph_config = LexicalGraphConfig(
        chunk_node_label="TextPart",
    )
    pipe.add_component(
        LexicalGraphBuilder(lexical_graph_config),
        "lexical_graph_builder",
    )
    pipe.add_component(Neo4jWriter(neo4j_driver), "writer")
    # define the execution order of the components
    # and how the output of previous components must be used
    pipe.connect(
        "splitter",
        "lexical_graph_builder",
        input_config={"text_chunks": "splitter"},
    )
    pipe.connect(
        "lexical_graph_builder",
        "writer",
        input_config={
            "graph": "lexical_graph_builder.graph",
            "lexical_graph_config": "lexical_graph_builder.config",
        },
    )
    # user input: the initial text to process
    pipe_inputs = {
        "splitter": {
            "text": """Albert Einstein was a German physicist born in 1879 who
            wrote many groundbreaking papers especially about general relativity
            and quantum mechanics. He worked for many different institutions, including
            the University of Bern in Switzerland and the University of Oxford."""
        },
        "lexical_graph_builder": {
            "document_info": {
                # 'path' can be anything
                "path": "example/lexical_graph_from_text.py"
            },
        },
    }
    # run the pipeline one step at a time, printing each intermediate result;
    # the last yielded item is the final PipelineResult
    final_result: PipelineResult | None = None
    async for step_result in pipe.run_step_by_step(pipe_inputs):
        print(step_result)
        final_result = step_result
        await asyncio.sleep(2)
    return final_result


if __name__ == "__main__":
    with neo4j.GraphDatabase.driver(
        "bolt://localhost:7687", auth=("neo4j", "password")
    ) as driver:
        print(asyncio.run(main(driver)))
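Each step_result printed above is expected to be an orchestrator Result (one per finished task, in dependency order: splitter, lexical_graph_builder, writer), followed by a final PipelineResult. A minimal consumer sketch under that assumption; the run_and_collect helper is illustrative, not part of this commit:

from neo4j_graphrag.experimental.pipeline import Pipeline
from neo4j_graphrag.experimental.pipeline.orchestrator import Result, ResultType

async def run_and_collect(pipe: Pipeline, pipe_inputs: dict) -> list:
    """Collect the data of each finished task instead of printing it."""
    finished = []
    async for item in pipe.run_step_by_step(pipe_inputs):
        # intermediate items are Result objects tagged TASK_FINISHED;
        # the last item is the pipeline's final PipelineResult
        if isinstance(item, Result) and item.result_type is ResultType.TASK_FINISHED:
            finished.append(item.data)
    return finished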

src/neo4j_graphrag/experimental/pipeline/orchestrator.py

Lines changed: 52 additions & 18 deletions

@@ -15,9 +15,13 @@
 from __future__ import annotations

 import asyncio
+import enum
 import logging
 import uuid
 import warnings
+import datetime
+
+from pydantic import BaseModel, Field
 from typing import TYPE_CHECKING, Any, AsyncGenerator

 from neo4j_graphrag.experimental.pipeline.exceptions import (
@@ -35,6 +39,20 @@
 logger = logging.getLogger(__name__)


+class ResultType(enum.Enum):
+    TASK_CHECKPOINT = "TASK_CHECKPOINT"
+    TASK_FINISHED = "TASK_FINISHED"
+    PIPELINE_FINISHED = "PIPELINE_FINISHED"
+
+
+class Result(BaseModel):
+    result_type: ResultType
+    data: Any
+    timestamp: datetime.datetime = Field(
+        default_factory=lambda: datetime.datetime.now(datetime.timezone.utc)
+    )
+
+
 class Orchestrator:
     """Orchestrate a pipeline.

@@ -53,17 +71,7 @@ def __init__(self, pipeline: "Pipeline"):
         self.event_notifier = EventNotifier(pipeline.callback)
         self.run_id = str(uuid.uuid4())

-    async def run_task(self, task: TaskPipelineNode, data: dict[str, Any]) -> None:
-        """Get inputs and run a specific task. Once the task is done,
-        calls the on_task_complete method.
-
-        Args:
-            task (TaskPipelineNode): The task to be run
-            data (dict[str, Any]): The pipeline input data
-
-        Returns:
-            None
-        """
+    async def run_single_task(self, task: TaskPipelineNode, data: dict[str, Any]) -> Result:
         param_mapping = self.get_input_config_for_task(task)
         inputs = await self.get_component_inputs(task.name, param_mapping, data)
         try:
@@ -72,13 +80,31 @@ async def run_task(self, task: TaskPipelineNode, data: dict[str, Any]) -> None:
             logger.debug(
                 f"ORCHESTRATOR: TASK ABORTED: {task.name} is already running or done, aborting"
             )
-            return None
+            raise StopAsyncIteration()
         await self.event_notifier.notify_task_started(self.run_id, task.name, inputs)
         res = await task.run(inputs)
         await self.set_task_status(task.name, RunStatus.DONE)
         await self.event_notifier.notify_task_finished(self.run_id, task.name, res)
         if res:
-            await self.on_task_complete(data=data, task=task, result=res)
+            await self.save_results(task=task, result=res)
+        return Result(result_type=ResultType.TASK_FINISHED, data=res.result if res else None)
+
+    async def run_task(self, task: TaskPipelineNode, data: dict[str, Any]) -> AsyncGenerator[Result, None]:
+        """Run a specific task and yield its result, then recursively run
+        its dependent tasks, yielding one Result per task.
+
+        Args:
+            task (TaskPipelineNode): The task to be run
+            data (dict[str, Any]): The pipeline input data
+        """
+        yield await self.run_single_task(task, data)
+        # then get the next tasks to be executed and run them
+        async for n in self.next(task):
+            async for res in self.run_task(n, data):
+                yield res

     async def set_task_status(self, task_name: str, status: RunStatus) -> None:
         """Set a new status
@@ -102,8 +128,8 @@ async def set_task_status(self, task_name: str, status: RunStatus) -> None:
             self.run_id, task_name, status.value
         )

-    async def on_task_complete(
-        self, data: dict[str, Any], task: TaskPipelineNode, result: RunResult
+    async def save_results(
+        self, task: TaskPipelineNode, result: RunResult
     ) -> None:
         """When a given task is complete, it will call this method
         to find the next tasks to run.
@@ -115,9 +141,6 @@ async def on_task_complete(
         await self.add_result_for_component(
             task.name, res_to_save, is_final=task.is_leaf()
         )
-        # then get the next tasks to be executed
-        # and run them in //
-        await asyncio.gather(*[self.run_task(n, data) async for n in self.next(task)])

     async def check_dependencies_complete(self, task: TaskPipelineNode) -> None:
         """Check that all parent tasks are complete.
@@ -257,3 +280,14 @@ async def run(self, data: dict[str, Any]) -> None:
         await self.event_notifier.notify_pipeline_finished(
             self.run_id, await self.pipeline.get_final_results(self.run_id)
         )
+
+    async def run_step_by_step(self, data: dict[str, Any]) -> AsyncGenerator[Result, None]:
+        """Run the pipeline, starting from the root nodes
+        (nodes without any parent); run_task then walks the
+        task dependency graph and yields one Result per task.
+        """
+        for root in self.pipeline.roots():
+            async for res in self.run_task(root, data):
+                yield res
+        # previous parallel implementation, kept for reference:
+        # tasks = [self.run_task(root, data) for root in self.pipeline.roots()]
+        # await asyncio.gather(*tasks)
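The key change above: run_task is now an async generator that yields its own result and then recursively drains its dependents, so branches that the old asyncio.gather call ran in parallel are now executed sequentially. A toy sketch of that traversal pattern, plain Python rather than library code:

import asyncio
from typing import AsyncGenerator

async def walk(node: str, deps: dict[str, list[str]]) -> AsyncGenerator[str, None]:
    yield node  # in the orchestrator, run_single_task happens here
    for child in deps.get(node, []):  # self.next(task) in the orchestrator
        async for name in walk(child, deps):  # recurse into dependents
            yield name

async def demo() -> None:
    deps = {"splitter": ["lexical_graph_builder"], "lexical_graph_builder": ["writer"]}
    async for name in walk("splitter", deps):
        print(name)  # prints: splitter, lexical_graph_builder, writer

asyncio.run(demo())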

src/neo4j_graphrag/experimental/pipeline/pipeline.py

Lines changed: 19 additions & 1 deletion

@@ -18,7 +18,7 @@
 import warnings
 from collections import defaultdict
 from timeit import default_timer
-from typing import Any, Optional
+from typing import Any, Optional, AsyncGenerator

 from neo4j_graphrag.utils.logging import prettify

@@ -420,3 +420,21 @@ async def run(self, data: dict[str, Any]) -> PipelineResult:
             run_id=orchestrator.run_id,
             result=await self.get_final_results(orchestrator.run_id),
         )
+
+    async def run_step_by_step(self, data: dict[str, Any]) -> AsyncGenerator[PipelineResult, None]:
+        logger.debug("PIPELINE START")
+        start_time = default_timer()
+        self.invalidate()
+        self.validate_input_data(data)
+        orchestrator = Orchestrator(self)
+        logger.debug(f"PIPELINE ORCHESTRATOR: {orchestrator.run_id}")
+        # intermediate values are orchestrator Result objects;
+        # only the final yield below is a PipelineResult
+        async for res in orchestrator.run_step_by_step(data):
+            yield res
+        end_time = default_timer()
+        logger.debug(
+            f"PIPELINE FINISHED {orchestrator.run_id} in {end_time - start_time}s"
+        )
+        yield PipelineResult(
+            run_id=orchestrator.run_id,
+            result=await self.get_final_results(orchestrator.run_id),
+        )
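Note that the annotated return type (AsyncGenerator[PipelineResult, None]) only matches the last yield; the intermediate yields are orchestrator Result objects. A sketch of a caller that relies on that distinction; the isinstance check is one possible way to tell the two apart and is not enforced by this commit:

from __future__ import annotations

from neo4j_graphrag.experimental.pipeline.orchestrator import Result
from neo4j_graphrag.experimental.pipeline.pipeline import Pipeline, PipelineResult

async def final_result(pipe: Pipeline, data: dict) -> PipelineResult | None:
    async for item in pipe.run_step_by_step(data):
        if isinstance(item, PipelineResult):
            return item  # the final, aggregated pipeline result
        # otherwise: a per-task step Result
    return None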
