
PoC: run pipeline step by step #270
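
This proof of concept adds a run_step_by_step method to the Orchestrator and the Pipeline so callers can consume a pipeline one task at a time instead of waiting for a single final result. A minimal consumption sketch (inside an async function, assuming an already configured Pipeline named pipe and a matching pipe_inputs dict, both illustrative; the complete example is the first file below):

    async for step_result in pipe.run_step_by_step(pipe_inputs):
        # an item is yielded as soon as each task finishes;
        # the last item is the final PipelineResult
        print(step_result)

Because the steps are driven by an async generator, the next task only starts when the caller asks for the next item, so intermediate results can be inspected between steps (the example below simply sleeps for two seconds).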

New example file:
@@ -0,0 +1,85 @@
from __future__ import annotations

import asyncio

from neo4j_graphrag.experimental.components.kg_writer import Neo4jWriter
from neo4j_graphrag.experimental.components.lexical_graph import LexicalGraphBuilder
from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import (
    FixedSizeSplitter,
)
from neo4j_graphrag.experimental.components.types import LexicalGraphConfig
from neo4j_graphrag.experimental.pipeline import Pipeline

import neo4j


async def main(neo4j_driver: neo4j.Driver) -> None:
    """Define and run the lexical graph builder pipeline step by step, instantiating
    a few components:

    - Text Splitter: to split the text into manageable chunks of fixed size
    - Lexical Graph Builder: to build the lexical graph, i.e. create the chunk nodes and the relationships between them
    - KG Writer: to save the lexical graph to Neo4j
    """
    pipe = Pipeline()
    # define the components
    pipe.add_component(
        FixedSizeSplitter(chunk_size=200, chunk_overlap=10, approximate=False),
        "splitter",
    )
    # optional: define some custom node labels for the lexical graph:
    lexical_graph_config = LexicalGraphConfig(
        chunk_node_label="TextPart",
    )
    pipe.add_component(
        LexicalGraphBuilder(lexical_graph_config),
        "lexical_graph_builder",
    )
    pipe.add_component(Neo4jWriter(neo4j_driver), "writer")
    # define the execution order of components
    # and how the output of previous components must be used
    pipe.connect(
        "splitter",
        "lexical_graph_builder",
        input_config={"text_chunks": "splitter"},
    )
    pipe.connect(
        "lexical_graph_builder",
        "writer",
        input_config={
            "graph": "lexical_graph_builder.graph",
            "lexical_graph_config": "lexical_graph_builder.config",
        },
    )
    # user input:
    # the initial text to build the lexical graph from
    # and some optional metadata about the source document
    pipe_inputs = {
        "splitter": {
            "text": """Albert Einstein was a German physicist born in 1879 who
            wrote many groundbreaking papers especially about general relativity
            and quantum mechanics. He worked for many different institutions, including
            the University of Bern in Switzerland and the University of Oxford."""
        },
        "lexical_graph_builder": {
            "document_info": {
                # 'path' can be anything
                "path": "example/lexical_graph_from_text.py"
            },
        },
    }
    # run the pipeline
    async for step_result in pipe.run_step_by_step(pipe_inputs):
        print(step_result)
        await asyncio.sleep(2)


if __name__ == "__main__":
    with neo4j.GraphDatabase.driver(
        "bolt://localhost:7687", auth=("neo4j", "password")
    ) as driver:
        asyncio.run(main(driver))
src/neo4j_graphrag/experimental/pipeline/orchestrator.py (52 additions & 18 deletions)
@@ -15,9 +15,13 @@
from __future__ import annotations

import asyncio
import enum
import logging
import uuid
import warnings
import datetime

from pydantic import BaseModel, Field
from typing import TYPE_CHECKING, Any, AsyncGenerator

from neo4j_graphrag.experimental.pipeline.exceptions import (
@@ -35,6 +39,20 @@
logger = logging.getLogger(__name__)


class ResultType(enum.Enum):
    TASK_CHECKPOINT = "TASK_CHECKPOINT"
    TASK_FINISHED = "TASK_FINISHED"
    PIPELINE_FINISHED = "PIPELINE_FINISHED"


class Result(BaseModel):
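    """Item yielded by the orchestrator while the pipeline runs: the type of
    event, its payload and a UTC timestamp."""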
    result_type: ResultType
    data: Any
    timestamp: datetime.datetime = Field(
        default_factory=lambda: datetime.datetime.now(datetime.timezone.utc)
    )


class Orchestrator:
    """Orchestrate a pipeline.

@@ -53,17 +71,7 @@ def __init__(self, pipeline: "Pipeline"):
        self.event_notifier = EventNotifier(pipeline.callback)
        self.run_id = str(uuid.uuid4())

    async def run_task(self, task: TaskPipelineNode, data: dict[str, Any]) -> None:
        """Get inputs and run a specific task. Once the task is done,
        calls the on_task_complete method.

        Args:
            task (TaskPipelineNode): The task to be run
            data (dict[str, Any]): The pipeline input data

        Returns:
            None
        """
    async def run_single_task(self, task: TaskPipelineNode, data: dict[str, Any]) -> Result:
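        """Get the inputs for a task, run it and save its result.

        Returns a Result wrapping the task output, or raises StopAsyncIteration
        if the task is already running or done.
        """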
        param_mapping = self.get_input_config_for_task(task)
        inputs = await self.get_component_inputs(task.name, param_mapping, data)
        try:
@@ -72,13 +80,31 @@ async def run_task(self, task: TaskPipelineNode, data: dict[str, Any]) -> None:
            logger.debug(
                f"ORCHESTRATOR: TASK ABORTED: {task.name} is already running or done, aborting"
            )
            return None
            raise StopAsyncIteration()
        await self.event_notifier.notify_task_started(self.run_id, task.name, inputs)
        res = await task.run(inputs)
        await self.set_task_status(task.name, RunStatus.DONE)
        await self.event_notifier.notify_task_finished(self.run_id, task.name, res)
        if res:
            await self.on_task_complete(data=data, task=task, result=res)
            await self.save_results(task=task, result=res)
        return Result(result_type=ResultType.TASK_FINISHED, data=res.result)

    async def run_task(self, task: TaskPipelineNode, data: dict[str, Any]) -> AsyncGenerator[Result, None]:
        """Get inputs and run a specific task, then recursively run the tasks
        depending on it, yielding a Result after each completed task.

        Args:
            task (TaskPipelineNode): The task to be run
            data (dict[str, Any]): The pipeline input data

        Yields:
            Result: one item per completed task
        """
        yield await self.run_single_task(task, data)
        # then get the next tasks to be executed and run them
        async for n in self.next(task):
            async for res in self.run_task(n, data):
                yield res

    async def set_task_status(self, task_name: str, status: RunStatus) -> None:
        """Set a new status
@@ -102,8 +128,8 @@ async def set_task_status(self, task_name: str, status: RunStatus) -> None:
            self.run_id, task_name, status.value
        )

    async def on_task_complete(
        self, data: dict[str, Any], task: TaskPipelineNode, result: RunResult
    async def save_results(
        self, task: TaskPipelineNode, result: RunResult
    ) -> None:
        """When a given task is complete, save its result so that
        downstream tasks and the final pipeline result can use it.
@@ -115,9 +141,6 @@ def on_task_complete(
        await self.add_result_for_component(
            task.name, res_to_save, is_final=task.is_leaf()
        )
        # then get the next tasks to be executed
        # and run them in //
        await asyncio.gather(*[self.run_task(n, data) async for n in self.next(task)])

    async def check_dependencies_complete(self, task: TaskPipelineNode) -> None:
        """Check that all parent tasks are complete.
@@ -257,3 +280,14 @@ async def run(self, data: dict[str, Any]) -> None:
        await self.event_notifier.notify_pipeline_finished(
            self.run_id, await self.pipeline.get_final_results(self.run_id)
        )

    async def run_step_by_step(self, data: dict[str, Any]) -> AsyncGenerator[Result, None]:
        """Run the pipeline step by step, starting from the root nodes
        (nodes without any parent), and yield a Result after each task
        completes. Task dependencies are resolved recursively in run_task.
        """
        for root in self.pipeline.roots():
            async for res in self.run_task(root, data):
                yield res
        # tasks = [self.run_task(root, data) for root in self.pipeline.roots()]
        # await asyncio.gather(*tasks)
src/neo4j_graphrag/experimental/pipeline/pipeline.py (19 additions & 1 deletion)
@@ -18,7 +18,7 @@
import warnings
from collections import defaultdict
from timeit import default_timer
from typing import Any, Optional
from typing import Any, Optional, AsyncGenerator

from neo4j_graphrag.utils.logging import prettify

@@ -420,3 +420,21 @@ async def run(self, data: dict[str, Any]) -> PipelineResult:
            run_id=orchestrator.run_id,
            result=await self.get_final_results(orchestrator.run_id),
        )

    async def run_step_by_step(self, data: dict[str, Any]) -> AsyncGenerator[PipelineResult, None]:
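        """Run the pipeline and yield each intermediate result as soon as it is
        available, ending with the final PipelineResult."""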
        logger.debug("PIPELINE START")
        start_time = default_timer()
        self.invalidate()
        self.validate_input_data(data)
        orchestrator = Orchestrator(self)
        logger.debug(f"PIPELINE ORCHESTRATOR: {orchestrator.run_id}")
        async for res in orchestrator.run_step_by_step(data):
            yield res
        end_time = default_timer()
        logger.debug(
            f"PIPELINE FINISHED {orchestrator.run_id} in {end_time - start_time}s"
        )
        yield PipelineResult(
            run_id=orchestrator.run_id,
            result=await self.get_final_results(orchestrator.run_id),
        )