Pipeline state dump and load #352

Draft · wants to merge 20 commits into base: main
Changes from 7 commits
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -6,6 +6,7 @@

- Added support for automatic schema extraction from text using LLMs. In the `SimpleKGPipeline`, automatic schema extraction is enabled by default when the user provides no schema.
- Added ability to return a user-defined message if context is empty in GraphRAG (which skips the LLM call).
- Added pipeline state management with `run_until`, `resume_from`, `dump_state`, and `load_state` methods, enabling pipeline execution checkpointing and resumption.

### Fixed

30 changes: 30 additions & 0 deletions docs/source/user_guide_pipeline.rst
@@ -133,6 +133,36 @@ can be added to the visualization by setting `hide_unused_outputs` to `False`:
webbrowser.open("pipeline_full.html")


*************************
Pipeline State Management
*************************

Pipelines support checkpointing and resumption through state management features:

.. code:: python

# Run pipeline until a specific component
state = await pipeline.run_until(data, stop_after="component_name", state_file="state.json")

# Resume pipeline from a specific component
result = await pipeline.resume_from(state, data, start_from="component_name")

# Alternatively, load state from file
result = await pipeline.resume_from(None, data, start_from="component_name", state_file="state.json")

The state contains:

- Pipeline configuration (parameter mappings between components and validation state)
- Execution results (outputs from completed components stored in the ResultStore)
- Final pipeline results from previous runs
- Component-specific state (interface available but not yet implemented by components)

This enables:

- Checkpointing long-running pipelines
- Debugging pipeline execution
- Resuming failed pipelines from the last successful component
- Comparing different component implementations on identical inputs: save the state just before the component under test and reuse it, so that non-deterministic upstream components do not change what it receives (see the sketch below)
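
For example, a long-running pipeline can be checkpointed right after an expensive
component and replayed from that point later, possibly in a separate process. The
sketch below is illustrative only: it assumes a pipeline wired with two components
named ``extractor`` and ``writer``, which are not defined in this guide.

.. code:: python

    # checkpoint once, right after the expensive upstream component
    state = await pipeline.run_until(
        data, stop_after="extractor", state_file="state.json"
    )

    # later: rebuild an equivalent pipeline (same component names and
    # connections) and replay only the downstream part against the saved
    # state, without re-running "extractor"
    result = await pipeline.resume_from(
        None, data, start_from="writer", state_file="state.json"
    )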


************************
Adding an Event Callback
************************
24 changes: 23 additions & 1 deletion src/neo4j_graphrag/experimental/pipeline/component.py
@@ -15,7 +15,7 @@
from __future__ import annotations

import inspect
from typing import Any, get_type_hints
from typing import Any, Dict, get_type_hints

from pydantic import BaseModel

@@ -108,3 +108,25 @@ async def run_with_context(
"""
# default behavior to prevent a breaking change
return await self.run(*args, **kwargs)

def serialize_state(self) -> Dict[str, Any]:
"""Serialize component state to a dictionary.

This method can be overridden by components to customize
their serialization behavior. By default, it returns an empty dict.

Returns:
Dict[str, Any]: Serialized state of the component
"""
return {}

def load_state(self, state: Dict[str, Any]) -> None:
"""Load component state from a serialized dictionary.

This method can be overridden by components to customize
their deserialization behavior. By default, it does nothing.

Args:
state (Dict[str, Any]): Previously serialized component state
"""
pass
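
For illustration, a component that tracks which items it has already processed could persist that set through these two hooks. This sketch is not part of the PR; note the set is converted to a list so the dumped state stays JSON-serializable.

from typing import Any, Dict, List

from neo4j_graphrag.experimental.pipeline import Component, DataModel


class IdsResult(DataModel):
    new_ids: List[str]


class DedupComponent(Component):
    """Hypothetical component that remembers ids it has already seen."""

    def __init__(self) -> None:
        self.seen: set = set()

    async def run(self, ids: List[str]) -> IdsResult:
        new = [i for i in ids if i not in self.seen]
        self.seen.update(new)
        return IdsResult(new_ids=new)

    def serialize_state(self) -> Dict[str, Any]:
        # sets are not JSON-serializable, so dump a sorted list instead
        return {"seen": sorted(self.seen)}

    def load_state(self, state: Dict[str, Any]) -> None:
        self.seen = set(state.get("seen", []))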
29 changes: 25 additions & 4 deletions src/neo4j_graphrag/experimental/pipeline/orchestrator.py
@@ -19,7 +19,7 @@
import uuid
import warnings
from functools import partial
from typing import TYPE_CHECKING, Any, AsyncGenerator
from typing import TYPE_CHECKING, Any, AsyncGenerator, Optional

from neo4j_graphrag.experimental.pipeline.types.context import RunContext
from neo4j_graphrag.experimental.pipeline.exceptions import (
@@ -46,16 +46,29 @@ class Orchestrator:
- finding the next tasks to execute
- building the inputs for each task
- calling the run method on each task
- optionally stopping after a specified component
- optionally starting from a specified component

Once a TaskNode is done, it calls the `on_task_complete` callback
that will save the results, find the next tasks to be executed
(checking that all dependencies are met), and run them.

Partial execution is supported through:
- stop_after: Stop execution after this component completes
- start_from: Start execution from this component instead of roots
"""

def __init__(self, pipeline: Pipeline):
def __init__(
self,
pipeline: Pipeline,
stop_after: Optional[str] = None,
start_from: Optional[str] = None,
):
self.pipeline = pipeline
self.event_notifier = EventNotifier(pipeline.callbacks)
self.run_id = str(uuid.uuid4())
self.stop_after = stop_after
self.start_from = start_from

async def run_task(self, task: TaskPipelineNode, data: dict[str, Any]) -> None:
"""Get inputs and run a specific task. Once the task is done,
Expand Down Expand Up @@ -129,7 +142,10 @@ async def on_task_complete(
await self.add_result_for_component(
task.name, res_to_save, is_final=task.is_leaf()
)
# then get the next tasks to be executed
# stop if this is the stop_after node
if self.stop_after and task.name == self.stop_after:
return
# otherwise, get the next tasks to be executed
# and run them in parallel
await asyncio.gather(*[self.run_task(n, data) async for n in self.next(task)])

@@ -266,7 +282,12 @@ async def run(self, data: dict[str, Any]) -> None:
will handle the task dependencies.
"""
await self.event_notifier.notify_pipeline_started(self.run_id, data)
tasks = [self.run_task(root, data) for root in self.pipeline.roots()]
# start from a specific node if requested, otherwise from roots
if self.start_from:
start_nodes = [self.pipeline.get_node_by_name(self.start_from)]
else:
start_nodes = self.pipeline.roots()
tasks = [self.run_task(root, data) for root in start_nodes]
await asyncio.gather(*tasks)
await self.event_notifier.notify_pipeline_finished(
self.run_id, await self.pipeline.get_final_results(self.run_id)
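
For context, the new `stop_after` and `start_from` arguments are wired in by the `run_until` and `resume_from` methods added to `Pipeline` further below. A rough sketch of that wiring, with illustrative component names:

from neo4j_graphrag.experimental.pipeline.orchestrator import Orchestrator


async def _sketch(pipeline, data):
    # roughly what Pipeline.run_until(...) does: stop once "extractor" completes
    orchestrator = Orchestrator(pipeline, stop_after="extractor")
    await orchestrator.run(data)

    # roughly what Pipeline.resume_from(...) does: seed execution at "writer"
    # instead of the pipeline roots
    orchestrator = Orchestrator(pipeline, start_from="writer")
    await orchestrator.run(data)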
142 changes: 141 additions & 1 deletion src/neo4j_graphrag/experimental/pipeline/pipeline.py
@@ -18,8 +18,9 @@
import warnings
from collections import defaultdict
from timeit import default_timer
from typing import Any, Optional, AsyncGenerator
from typing import Any, Optional, AsyncGenerator, Dict
import asyncio
import json

from neo4j_graphrag.utils.logging import prettify

@@ -579,3 +580,142 @@ async def run(self, data: dict[str, Any]) -> PipelineResult:
run_id=orchestrator.run_id,
result=await self.get_final_results(orchestrator.run_id),
)

def dump_state(self, run_id: str) -> Dict[str, Any]:
"""Dump the current state of the pipeline and its components to a serializable dictionary.

Args:
run_id: The run_id that was used when the pipeline was executed

Returns:
dict[str, Any]: A serializable dictionary containing the pipeline state
"""
pipeline_state: Dict[str, Any] = {
"run_id": run_id,
"components": {},
"store": self.store.dump() if hasattr(self.store, "dump") else {},
"final_results": self.final_results.dump()
if hasattr(self.final_results, "dump")
else {},
"is_validated": self.is_validated,
"param_mapping": self.param_mapping,
"missing_inputs": self.missing_inputs,
}

components_dict: Dict[str, Any] = {}
pipeline_state["components"] = components_dict

# serialize each component's state
for name, node in self._nodes.items():
components_dict[name] = node.component.serialize_state()

return pipeline_state

def load_state(self, state: Dict[str, Any]) -> None:
"""Load pipeline state from a serialized dictionary.

Args:
state (dict[str, Any]): Previously serialized pipeline state
"""
# load component states
for name, component_state in state.get("components", {}).items():
if name in self._nodes:
self._nodes[name].component.load_state(component_state)

# load other pipeline state attributes
if "is_validated" in state:
self.is_validated = state["is_validated"]

if "param_mapping" in state:
self.param_mapping = state["param_mapping"]

if "missing_inputs" in state:
self.missing_inputs = state["missing_inputs"]

# load store data if store has load method
if "store" in state and hasattr(self.store, "load"):
self.store.load(state["store"])

# load final results if it has load method
if "final_results" in state and hasattr(self.final_results, "load"):
self.final_results.load(state["final_results"])

async def run_until(
self, data: Dict[str, Any], stop_after: str, state_file: Optional[str] = None
) -> Dict[str, Any]:
"""
Run the pipeline until a specific component and return the state.

Args:
data (Dict[str, Any]): The input data for the pipeline.
stop_after (str): The name of the component to stop after.
state_file (Optional[str]): If provided, save the state to this file as JSON.

Returns:
Dict[str, Any]: The serialized state of the pipeline after execution.
"""
logger.debug("PIPELINE START (RUN UNTIL)")
start_time = default_timer()
self.invalidate()
self.validate_input_data(data)
orchestrator = Orchestrator(self, stop_after=stop_after)
logger.debug(f"PIPELINE ORCHESTRATOR: {orchestrator.run_id}")
await orchestrator.run(data)
end_time = default_timer()
logger.debug(
f"PIPELINE FINISHED (RUN UNTIL) {orchestrator.run_id} in {end_time - start_time}s"
)
state = self.dump_state(orchestrator.run_id)
if state_file:
with open(state_file, "w", encoding="utf-8") as f:
json.dump(state, f, ensure_ascii=False, indent=2)
return state

async def resume_from(
self,
state: Optional[
Dict[str, Any]
], # Required but can be None if state_file is provided
data: Dict[str, Any],
start_from: str,
state_file: Optional[str] = None,
) -> PipelineResult:
"""
Resume pipeline execution from a specific component using a saved state.

Args:
state (Optional[Dict[str, Any]]): The serialized pipeline state. Required, but can be None if state_file is provided.
data (Dict[str, Any]): Additional input data for the pipeline.
start_from (str): The name of the component to start execution from.
state_file (Optional[str]): If provided, load the state from this file as JSON. Required if state is None.

Returns:
PipelineResult: The result of the pipeline execution.

Raises:
ValueError: If neither state nor state_file is provided.
"""
if state_file:
with open(state_file, "r", encoding="utf-8") as f:
state = json.load(f)
if state is None:
raise ValueError("No state provided for resume_from.")
self.load_state(state)
run_id = state.get("run_id")
if not run_id:
raise ValueError("No run_id found in state. Cannot resume execution.")
logger.debug("PIPELINE START (RESUME FROM)")
start_time = default_timer()
self.validate_input_data(data)
orchestrator = Orchestrator(self, start_from=start_from)
orchestrator.run_id = run_id # Use the original run_id
logger.debug(f"PIPELINE ORCHESTRATOR: {orchestrator.run_id}")
await orchestrator.run(data)
end_time = default_timer()
logger.debug(
f"PIPELINE FINISHED (RESUME FROM) {orchestrator.run_id} in {end_time - start_time}s"
)
return PipelineResult(
run_id=orchestrator.run_id,
result=await self.get_final_results(orchestrator.run_id),
)
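
For reference, the state returned by `run_until` (and written to `state_file`) is a plain dict mirroring `dump_state` above; the values shown here are illustrative placeholders only:

state = {
    "run_id": "e4b1c6f0-...",   # uuid of the original run
    "components": {},            # per-component serialize_state() output
    "store": {},                 # ResultStore.dump(), when available
    "final_results": {},         # final results store dump, when available
    "is_validated": True,
    "param_mapping": {},         # parameter mappings between components
    "missing_inputs": {},
}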
16 changes: 16 additions & 0 deletions tests/unit/experimental/pipeline/components.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from typing import Dict, Any

from neo4j_graphrag.experimental.pipeline import Component, DataModel
from neo4j_graphrag.experimental.pipeline.types.context import RunContext
@@ -63,3 +64,18 @@ def __init__(self, sleep: float = 1.0) -> None:
async def run(self, number1: int, number2: int = 2) -> IntResultModel:
await asyncio.sleep(self.sleep)
return IntResultModel(result=number1 * number2)


class StatefulComponent(Component):
def __init__(self) -> None:
self.counter = 0

async def run(self, value: int) -> IntResultModel:
self.counter += value
return IntResultModel(result=self.counter)

def serialize_state(self) -> Dict[str, Any]:
return {"counter": self.counter}

def load_state(self, state: Dict[str, Any]) -> None:
self.counter = state.get("counter", 0)
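
A sketch of how StatefulComponent ties into the pipeline-level state; the pipeline wiring below is illustrative and not part of this test module:

import asyncio

from neo4j_graphrag.experimental.pipeline import Pipeline


async def main() -> None:
    pipe = Pipeline()
    pipe.add_component(StatefulComponent(), "stateful")
    result = await pipe.run({"stateful": {"value": 5}})
    # the component's counter shows up under its name in the dumped state
    state = pipe.dump_state(result.run_id)
    assert state["components"]["stateful"] == {"counter": 5}


if __name__ == "__main__":
    asyncio.run(main())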
17 changes: 16 additions & 1 deletion tests/unit/experimental/pipeline/test_component.py
@@ -18,7 +18,12 @@

from neo4j_graphrag.experimental.pipeline import Component
from neo4j_graphrag.experimental.pipeline.types.context import RunContext
from .components import ComponentMultiply, ComponentMultiplyWithContext, IntResultModel
from .components import (
ComponentMultiply,
ComponentMultiplyWithContext,
IntResultModel,
StatefulComponent,
)


def test_component_inputs() -> None:
@@ -87,3 +92,13 @@ class WrongComponent(Component):
"You must implement either `run` or `run_with_context` in Component 'WrongComponent'"
in str(e)
)


def test_stateful_component_serialize_and_load_state() -> None:
c = StatefulComponent()
c.counter = 42
state = c.serialize_state()
assert state == {"counter": 42}
c.counter = 0
c.load_state({"counter": 99})
assert c.counter == 99