From 4a93761cff15fc2498af8d37f7e999d5d74df7b1 Mon Sep 17 00:00:00 2001
From: nathaliecharbel
Date: Thu, 15 May 2025 11:36:19 +0200
Subject: [PATCH 01/20] Serialize/deserialize component state

---
 .../experimental/pipeline/component.py | 24 ++++++++++++++++++-
 1 file changed, 23 insertions(+), 1 deletion(-)

diff --git a/src/neo4j_graphrag/experimental/pipeline/component.py b/src/neo4j_graphrag/experimental/pipeline/component.py
index 39a2816ef..69289feb4 100644
--- a/src/neo4j_graphrag/experimental/pipeline/component.py
+++ b/src/neo4j_graphrag/experimental/pipeline/component.py
@@ -15,7 +15,7 @@
 from __future__ import annotations

 import inspect
-from typing import Any, get_type_hints
+from typing import Any, Dict, get_type_hints

 from pydantic import BaseModel

@@ -108,3 +108,25 @@ async def run_with_context(
         """
         # default behavior to prevent a breaking change
         return await self.run(*args, **kwargs)
+
+    def serialize_state(self) -> Dict[str, Any]:
+        """Serialize component state to a dictionary.
+
+        This method can be overridden by components to customize
+        their serialization behavior. By default, it returns an empty dict.
+
+        Returns:
+            Dict[str, Any]: Serialized state of the component
+        """
+        return {}
+
+    def load_state(self, state: Dict[str, Any]) -> None:
+        """Load component state from a serialized dictionary.
+
+        This method can be overridden by components to customize
+        their deserialization behavior. By default, it does nothing.
+
+        Args:
+            state (Dict[str, Any]): Previously serialized component state
+        """
+        pass

From f94050627ce6e34b19d9a3b09cd90cc8f058d503 Mon Sep 17 00:00:00 2001
From: nathaliecharbel
Date: Thu, 15 May 2025 11:39:00 +0200
Subject: [PATCH 02/20] Run pipeline until/Resume pipeline from

---
 .../experimental/pipeline/orchestrator.py | 29 +++++-
 .../experimental/pipeline/pipeline.py     | 93 ++++++++++++++++++-
 .../experimental/pipeline/stores.py       | 35 ++++++-
 3 files changed, 151 insertions(+), 6 deletions(-)

diff --git a/src/neo4j_graphrag/experimental/pipeline/orchestrator.py b/src/neo4j_graphrag/experimental/pipeline/orchestrator.py
index de9468bf7..550b99ecf 100644
--- a/src/neo4j_graphrag/experimental/pipeline/orchestrator.py
+++ b/src/neo4j_graphrag/experimental/pipeline/orchestrator.py
@@ -19,7 +19,7 @@
 import uuid
 import warnings
 from functools import partial
-from typing import TYPE_CHECKING, Any, AsyncGenerator
+from typing import TYPE_CHECKING, Any, AsyncGenerator, Optional

 from neo4j_graphrag.experimental.pipeline.types.context import RunContext
 from neo4j_graphrag.experimental.pipeline.exceptions import (
@@ -46,16 +46,29 @@ class Orchestrator:
     - finding the next tasks to execute
     - building the inputs for each task
     - calling the run method on each task
+    - optionally stopping after a specified component
+    - optionally starting from a specified component

     Once a TaskNode is done, it calls the `on_task_complete` callback
     that will save the results, find the next tasks to be executed
     (checking that all dependencies are met), and run them.
+ + Partial execution is supported through: + - stop_after: Stop execution after this component completes + - start_from: Start execution from this component instead of roots """ - def __init__(self, pipeline: Pipeline): + def __init__( + self, + pipeline: Pipeline, + stop_after: Optional[str] = None, + start_from: Optional[str] = None, + ): self.pipeline = pipeline self.event_notifier = EventNotifier(pipeline.callbacks) self.run_id = str(uuid.uuid4()) + self.stop_after = stop_after + self.start_from = start_from async def run_task(self, task: TaskPipelineNode, data: dict[str, Any]) -> None: """Get inputs and run a specific task. Once the task is done, @@ -129,7 +142,10 @@ async def on_task_complete( await self.add_result_for_component( task.name, res_to_save, is_final=task.is_leaf() ) - # then get the next tasks to be executed + # stop if this is the stop_after node + if self.stop_after and task.name == self.stop_after: + return + # otherwise, get the next tasks to be executed # and run them in // await asyncio.gather(*[self.run_task(n, data) async for n in self.next(task)]) @@ -266,7 +282,12 @@ async def run(self, data: dict[str, Any]) -> None: will handle the task dependencies. """ await self.event_notifier.notify_pipeline_started(self.run_id, data) - tasks = [self.run_task(root, data) for root in self.pipeline.roots()] + # start from a specific node if requested, otherwise from roots + if self.start_from: + start_nodes = [self.pipeline.get_node_by_name(self.start_from)] + else: + start_nodes = self.pipeline.roots() + tasks = [self.run_task(root, data) for root in start_nodes] await asyncio.gather(*tasks) await self.event_notifier.notify_pipeline_finished( self.run_id, await self.pipeline.get_final_results(self.run_id) diff --git a/src/neo4j_graphrag/experimental/pipeline/pipeline.py b/src/neo4j_graphrag/experimental/pipeline/pipeline.py index 91dab34ee..23b27ee57 100644 --- a/src/neo4j_graphrag/experimental/pipeline/pipeline.py +++ b/src/neo4j_graphrag/experimental/pipeline/pipeline.py @@ -18,7 +18,7 @@ import warnings from collections import defaultdict from timeit import default_timer -from typing import Any, Optional, AsyncGenerator +from typing import Any, Optional, AsyncGenerator, Dict import asyncio from neo4j_graphrag.utils.logging import prettify @@ -579,3 +579,94 @@ async def run(self, data: dict[str, Any]) -> PipelineResult: run_id=orchestrator.run_id, result=await self.get_final_results(orchestrator.run_id), ) + + def dump_state(self) -> Dict[str, Any]: + """Dump the current state of the pipeline and its components + to a serializable dictionary. + + Returns: + dict[str, Any]: A serializable dictionary containing the pipeline state + """ + pipeline_state = { + "components": {}, + "store": self.store.dump() if hasattr(self.store, "dump") else {}, + "final_results": self.final_results.dump() + if hasattr(self.final_results, "dump") + else {}, + "is_validated": self.is_validated, + "param_mapping": self.param_mapping, + "missing_inputs": self.missing_inputs, + } + + components_dict: Dict[str, Any] = {} + pipeline_state["components"] = components_dict + + # serialize each component's state + for name, node in self._nodes.items(): + components_dict[name] = node.component.serialize_state() + + return pipeline_state + + def load_state(self, state: Dict[str, Any]) -> None: + """Load pipeline state from a serialized dictionary. 
+ + Args: + state (dict[str, Any]): Previously serialized pipeline state + """ + # load component states + for name, component_state in state.get("components", {}).items(): + if name in self._nodes: + self._nodes[name].component.load_state(component_state) + + # load other pipeline state attributes + if "is_validated" in state: + self.is_validated = state["is_validated"] + + if "param_mapping" in state: + self.param_mapping = state["param_mapping"] + + if "missing_inputs" in state: + self.missing_inputs = state["missing_inputs"] + + # load store data if store has load method + if "store" in state and hasattr(self.store, "load"): + self.store.load(state["store"]) + + # load final results if it has load method + if "final_results" in state and hasattr(self.final_results, "load"): + self.final_results.load(state["final_results"]) + + async def run_until(self, data: Dict[str, Any], stop_after: str) -> Dict[str, Any]: + """Run the pipeline until a specific component and return the state.""" + logger.debug("PIPELINE START (RUN UNTIL)") + start_time = default_timer() + self.invalidate() + self.validate_input_data(data) + orchestrator = Orchestrator(self, stop_after=stop_after) + logger.debug(f"PIPELINE ORCHESTRATOR: {orchestrator.run_id}") + await orchestrator.run(data) + end_time = default_timer() + logger.debug( + f"PIPELINE FINISHED (RUN UNTIL) {orchestrator.run_id} in {end_time - start_time}s" + ) + return self.dump_state() + + async def resume_from( + self, state: Dict[str, Any], data: Dict[str, Any], start_from: str + ) -> "PipelineResult": + """Resume pipeline execution from a specific component using a saved state.""" + self.load_state(state) + logger.debug("PIPELINE START (RESUME FROM)") + start_time = default_timer() + self.validate_input_data(data) + orchestrator = Orchestrator(self, start_from=start_from) + logger.debug(f"PIPELINE ORCHESTRATOR: {orchestrator.run_id}") + await orchestrator.run(data) + end_time = default_timer() + logger.debug( + f"PIPELINE FINISHED (RESUME FROM) {orchestrator.run_id} in {end_time - start_time}s" + ) + return PipelineResult( + run_id=orchestrator.run_id, + result=await self.get_final_results(orchestrator.run_id), + ) diff --git a/src/neo4j_graphrag/experimental/pipeline/stores.py b/src/neo4j_graphrag/experimental/pipeline/stores.py index 855546d6d..b5bec7bd7 100644 --- a/src/neo4j_graphrag/experimental/pipeline/stores.py +++ b/src/neo4j_graphrag/experimental/pipeline/stores.py @@ -20,7 +20,7 @@ import abc import asyncio -from typing import Any +from typing import Any, Dict class Store(abc.ABC): @@ -90,6 +90,22 @@ async def add_result_for_component( async def get_result_for_component(self, run_id: str, task_name: str) -> Any: return await self.get(self.get_key(run_id, task_name)) + def dump(self) -> Dict[str, Any]: + """Dump the store data to a serializable dictionary. + + Returns: + dict[str, Any]: A serializable representation of the store data + """ + raise NotImplementedError("Subclasses must implement this method") + + def load(self, data: Dict[str, Any]) -> None: + """Load store data from a serialized dictionary. + + Args: + data (dict[str, Any]): Previously serialized store data + """ + raise NotImplementedError("Subclasses must implement this method") + class InMemoryStore(ResultStore): """Simple in-memory store. @@ -115,3 +131,20 @@ def all(self) -> dict[str, Any]: def empty(self) -> None: self._data = {} + + def dump(self) -> Dict[str, Any]: + """Dump the in-memory store data to a serializable dictionary. 
+ + Returns: + dict[str, Any]: A serializable representation of the store data + """ + return {"data": self._data.copy()} + + def load(self, data: Dict[str, Any]) -> None: + """Load in-memory store data from a serialized dictionary. + + Args: + data (dict[str, Any]): Previously serialized store data + """ + if "data" in data: + self._data = data["data"].copy() From 12d2e26adde0b6c3d77bad4765bc48649549744e Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Wed, 4 Jun 2025 10:50:46 +0200 Subject: [PATCH 03/20] Remove in memory storage support for pipeline state --- .../experimental/pipeline/stores.py | 35 +------------------ 1 file changed, 1 insertion(+), 34 deletions(-) diff --git a/src/neo4j_graphrag/experimental/pipeline/stores.py b/src/neo4j_graphrag/experimental/pipeline/stores.py index b5bec7bd7..855546d6d 100644 --- a/src/neo4j_graphrag/experimental/pipeline/stores.py +++ b/src/neo4j_graphrag/experimental/pipeline/stores.py @@ -20,7 +20,7 @@ import abc import asyncio -from typing import Any, Dict +from typing import Any class Store(abc.ABC): @@ -90,22 +90,6 @@ async def add_result_for_component( async def get_result_for_component(self, run_id: str, task_name: str) -> Any: return await self.get(self.get_key(run_id, task_name)) - def dump(self) -> Dict[str, Any]: - """Dump the store data to a serializable dictionary. - - Returns: - dict[str, Any]: A serializable representation of the store data - """ - raise NotImplementedError("Subclasses must implement this method") - - def load(self, data: Dict[str, Any]) -> None: - """Load store data from a serialized dictionary. - - Args: - data (dict[str, Any]): Previously serialized store data - """ - raise NotImplementedError("Subclasses must implement this method") - class InMemoryStore(ResultStore): """Simple in-memory store. @@ -131,20 +115,3 @@ def all(self) -> dict[str, Any]: def empty(self) -> None: self._data = {} - - def dump(self) -> Dict[str, Any]: - """Dump the in-memory store data to a serializable dictionary. - - Returns: - dict[str, Any]: A serializable representation of the store data - """ - return {"data": self._data.copy()} - - def load(self, data: Dict[str, Any]) -> None: - """Load in-memory store data from a serialized dictionary. 
- - Args: - data (dict[str, Any]): Previously serialized store data - """ - if "data" in data: - self._data = data["data"].copy() From 8a00015ddf6fff53b44d7ea8d35c2bc91b63371f Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Wed, 4 Jun 2025 11:15:30 +0200 Subject: [PATCH 04/20] Add pipeline run_id and ability to save and load state from json file --- .../experimental/pipeline/pipeline.py | 71 ++++++++++++++++--- 1 file changed, 60 insertions(+), 11 deletions(-) diff --git a/src/neo4j_graphrag/experimental/pipeline/pipeline.py b/src/neo4j_graphrag/experimental/pipeline/pipeline.py index 23b27ee57..d8ca842b6 100644 --- a/src/neo4j_graphrag/experimental/pipeline/pipeline.py +++ b/src/neo4j_graphrag/experimental/pipeline/pipeline.py @@ -18,8 +18,9 @@ import warnings from collections import defaultdict from timeit import default_timer -from typing import Any, Optional, AsyncGenerator, Dict +from typing import Any, Optional, AsyncGenerator, Dict, Union import asyncio +import json from neo4j_graphrag.utils.logging import prettify @@ -580,14 +581,17 @@ async def run(self, data: dict[str, Any]) -> PipelineResult: result=await self.get_final_results(orchestrator.run_id), ) - def dump_state(self) -> Dict[str, Any]: - """Dump the current state of the pipeline and its components - to a serializable dictionary. + def dump_state(self, run_id: str) -> Dict[str, Any]: + """Dump the current state of the pipeline and its components to a serializable dictionary. + + Args: + run_id: The run_id that was used when the pipeline was executed Returns: dict[str, Any]: A serializable dictionary containing the pipeline state """ - pipeline_state = { + pipeline_state: Dict[str, Any] = { + "run_id": run_id, "components": {}, "store": self.store.dump() if hasattr(self.store, "dump") else {}, "final_results": self.final_results.dump() @@ -636,8 +640,20 @@ def load_state(self, state: Dict[str, Any]) -> None: if "final_results" in state and hasattr(self.final_results, "load"): self.final_results.load(state["final_results"]) - async def run_until(self, data: Dict[str, Any], stop_after: str) -> Dict[str, Any]: - """Run the pipeline until a specific component and return the state.""" + async def run_until( + self, data: Dict[str, Any], stop_after: str, state_file: Optional[str] = None + ) -> Dict[str, Any]: + """ + Run the pipeline until a specific component and return the state. + + Args: + data (Dict[str, Any]): The input data for the pipeline. + stop_after (str): The name of the component to stop after. + state_file (Optional[str]): If provided, save the state to this file as JSON. + + Returns: + Dict[str, Any]: The serialized state of the pipeline after execution. 
+ """ logger.debug("PIPELINE START (RUN UNTIL)") start_time = default_timer() self.invalidate() @@ -649,17 +665,50 @@ async def run_until(self, data: Dict[str, Any], stop_after: str) -> Dict[str, An logger.debug( f"PIPELINE FINISHED (RUN UNTIL) {orchestrator.run_id} in {end_time - start_time}s" ) - return self.dump_state() + state = self.dump_state(orchestrator.run_id) + if state_file: + with open(state_file, "w", encoding="utf-8") as f: + json.dump(state, f, ensure_ascii=False, indent=2) + return state async def resume_from( - self, state: Dict[str, Any], data: Dict[str, Any], start_from: str - ) -> "PipelineResult": - """Resume pipeline execution from a specific component using a saved state.""" + self, + state: Optional[ + Dict[str, Any] + ], # Required but can be None if state_file is provided + data: Dict[str, Any], + start_from: str, + state_file: Optional[str] = None, + ) -> PipelineResult: + """ + Resume pipeline execution from a specific component using a saved state. + + Args: + state (Optional[Dict[str, Any]]): The serialized pipeline state. Required, but can be None if state_file is provided. + data (Dict[str, Any]): Additional input data for the pipeline. + start_from (str): The name of the component to start execution from. + state_file (Optional[str]): If provided, load the state from this file as JSON. Required if state is None. + + Returns: + PipelineResult: The result of the pipeline execution. + + Raises: + ValueError: If neither state nor state_file is provided. + """ + if state_file: + with open(state_file, "r", encoding="utf-8") as f: + state = json.load(f) + if state is None: + raise ValueError("No state provided for resume_from.") self.load_state(state) + run_id = state.get("run_id") + if not run_id: + raise ValueError("No run_id found in state. Cannot resume execution.") logger.debug("PIPELINE START (RESUME FROM)") start_time = default_timer() self.validate_input_data(data) orchestrator = Orchestrator(self, start_from=start_from) + orchestrator.run_id = run_id # Use the original run_id logger.debug(f"PIPELINE ORCHESTRATOR: {orchestrator.run_id}") await orchestrator.run(data) end_time = default_timer() From 5f0c8925bae3e9e7c170157213bbce2d09782a3a Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Wed, 4 Jun 2025 11:15:41 +0200 Subject: [PATCH 05/20] Add unit tests --- .../unit/experimental/pipeline/components.py | 16 +++ .../experimental/pipeline/test_component.py | 17 ++- .../experimental/pipeline/test_pipeline.py | 133 +++++++++++++++++- 3 files changed, 164 insertions(+), 2 deletions(-) diff --git a/tests/unit/experimental/pipeline/components.py b/tests/unit/experimental/pipeline/components.py index e1b2e1845..7fd052cd0 100644 --- a/tests/unit/experimental/pipeline/components.py +++ b/tests/unit/experimental/pipeline/components.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import asyncio +from typing import Dict, Any from neo4j_graphrag.experimental.pipeline import Component, DataModel from neo4j_graphrag.experimental.pipeline.types.context import RunContext @@ -63,3 +64,18 @@ def __init__(self, sleep: float = 1.0) -> None: async def run(self, number1: int, number2: int = 2) -> IntResultModel: await asyncio.sleep(self.sleep) return IntResultModel(result=number1 * number2) + + +class StatefulComponent(Component): + def __init__(self) -> None: + self.counter = 0 + + async def run(self, value: int) -> IntResultModel: + self.counter += value + return IntResultModel(result=self.counter) + + def serialize_state(self) -> Dict[str, Any]: + return {"counter": self.counter} + + def load_state(self, state: Dict[str, Any]) -> None: + self.counter = state.get("counter", 0) diff --git a/tests/unit/experimental/pipeline/test_component.py b/tests/unit/experimental/pipeline/test_component.py index ba39aebaf..3d77803b4 100644 --- a/tests/unit/experimental/pipeline/test_component.py +++ b/tests/unit/experimental/pipeline/test_component.py @@ -18,7 +18,12 @@ from neo4j_graphrag.experimental.pipeline import Component from neo4j_graphrag.experimental.pipeline.types.context import RunContext -from .components import ComponentMultiply, ComponentMultiplyWithContext, IntResultModel +from .components import ( + ComponentMultiply, + ComponentMultiplyWithContext, + IntResultModel, + StatefulComponent, +) def test_component_inputs() -> None: @@ -87,3 +92,13 @@ class WrongComponent(Component): "You must implement either `run` or `run_with_context` in Component 'WrongComponent'" in str(e) ) + + +def test_stateful_component_serialize_and_load_state() -> None: + c = StatefulComponent() + c.counter = 42 + state = c.serialize_state() + assert state == {"counter": 42} + c.counter = 0 + c.load_state({"counter": 99}) + assert c.counter == 99 diff --git a/tests/unit/experimental/pipeline/test_pipeline.py b/tests/unit/experimental/pipeline/test_pipeline.py index d37fc65fc..6f30d33ea 100644 --- a/tests/unit/experimental/pipeline/test_pipeline.py +++ b/tests/unit/experimental/pipeline/test_pipeline.py @@ -17,7 +17,7 @@ import asyncio import datetime import tempfile -from typing import Sized +from typing import Sized, Any, Dict from unittest import mock from unittest.mock import AsyncMock, call, patch @@ -38,6 +38,7 @@ ComponentMultiply, ComponentNoParam, ComponentPassThrough, + StatefulComponent, StringResultModel, SlowComponentMultiply, ) @@ -590,3 +591,133 @@ async def callback(event: Event) -> None: events.append(e) assert len(events) == 2 assert len(pipe.callbacks) == 1 + + +@pytest.mark.asyncio +async def test_pipeline_state_dump_and_load() -> None: + pipe = Pipeline() + c1 = StatefulComponent() + c2 = StatefulComponent() + pipe.add_component(c1, "a") + pipe.add_component(c2, "b") + pipe.connect("a", "b", {"value": "a.result"}) + result = await pipe.run({"a": {"value": 5}}) + c1.counter = 123 # simulate state change + state = pipe.dump_state(result.run_id) + c1.counter = 0 + pipe.load_state(state) + assert c1.counter == 123 + + +@pytest.fixture +def stateful_pipeline() -> tuple[Pipeline, StatefulComponent, StatefulComponent]: + """Fixture that creates a pipeline with two stateful components connected in sequence.""" + pipe = Pipeline() + c1 = StatefulComponent() + c2 = StatefulComponent() + pipe.add_component(c1, "a") + pipe.add_component(c2, "b") + pipe.connect("a", "b", {"value": "a.result"}) + return pipe, c1, c2 + + +@pytest.mark.asyncio +async def test_pipeline_run_until_and_resume_from( + 
stateful_pipeline: tuple[Pipeline, StatefulComponent, StatefulComponent], +) -> None: + pipe, _, c2 = stateful_pipeline + + # run until first component + state = await pipe.run_until({"a": {"value": 7}}, stop_after="a") + + # verify that the state contains the run_id + assert "run_id" in state + run_id = state["run_id"] + assert isinstance(run_id, str) + assert len(run_id) > 0 + + # change c2 to a new instance to test resume + c2_new = StatefulComponent() + pipe.set_component("b", c2_new) + + # resume from b + result = await pipe.resume_from(state, {"a": {"value": 5}}, start_from="b") + + # verify we're using the same run_id and correct value propagation + assert result.run_id == run_id + # c2_new should have received value=7 (from a.result) + assert result.result["b"]["result"] == 7 + + +@pytest.mark.asyncio +async def test_pipeline_run_until_with_state_file( + stateful_pipeline: tuple[Pipeline, StatefulComponent, StatefulComponent], +) -> None: + pipe, _, c2 = stateful_pipeline + + # Create a temporary file for state + with tempfile.NamedTemporaryFile(suffix=".json") as state_file: + # run until first component and save state to file + state = await pipe.run_until( + {"a": {"value": 7}}, stop_after="a", state_file=state_file.name + ) + + # verify that the state contains the run_id + assert "run_id" in state + run_id = state["run_id"] + assert isinstance(run_id, str) + assert len(run_id) > 0 + + # change c2 to a new instance to test resume + c2_new = StatefulComponent() + pipe.set_component("b", c2_new) + + # resume from b using state file + result = await pipe.resume_from( + None, {"a": {"value": 5}}, start_from="b", state_file=state_file.name + ) + + # verify we're using the same run_id and correct value propagation + assert result.run_id == run_id + # c2_new should have received value=7 (from a.result) + assert result.result["b"]["result"] == 7 + + +@pytest.mark.asyncio +async def test_pipeline_resume_from_missing_state() -> None: + pipe = Pipeline() + c1 = StatefulComponent() + pipe.add_component(c1, "a") + + with pytest.raises(ValueError, match="No state provided for resume_from."): + await pipe.resume_from(None, {"a": {"value": 5}}, start_from="a") + + +@pytest.mark.asyncio +async def test_pipeline_resume_from_invalid_state() -> None: + pipe = Pipeline() + c1 = StatefulComponent() + pipe.add_component(c1, "a") + + invalid_state: Dict[str, Dict[str, Any]] = { + "components": {}, + "store": {}, + "final_results": {}, + } # missing run_id + + with pytest.raises( + ValueError, match="No run_id found in state. Cannot resume execution." + ): + await pipe.resume_from(invalid_state, {"a": {"value": 5}}, start_from="a") + + +@pytest.mark.asyncio +async def test_pipeline_resume_from_nonexistent_state_file() -> None: + pipe = Pipeline() + c1 = StatefulComponent() + pipe.add_component(c1, "a") + + with pytest.raises(FileNotFoundError): + await pipe.resume_from( + None, {"a": {"value": 5}}, start_from="a", state_file="nonexistent.json" + ) From 22722b8e82d50da013930534fed7175508afde6c Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Wed, 4 Jun 2025 11:35:41 +0200 Subject: [PATCH 06/20] Update changelog and docs --- CHANGELOG.md | 1 + docs/source/user_guide_pipeline.rst | 30 +++++++++++++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index a8bad9486..ab5362af9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ - Added support for automatic schema extraction from text using LLMs. 
   In the `SimpleKGPipeline`, when the user provides no schema, the automatic schema extraction is enabled by default.
 - Added ability to return a user-defined message if context is empty in GraphRAG (which skips the LLM call).
+- Added pipeline state management with `run_until`, `resume_from`, `dump_state`, and `load_state` methods, enabling pipeline execution checkpointing and resumption.

 ### Fixed

diff --git a/docs/source/user_guide_pipeline.rst b/docs/source/user_guide_pipeline.rst
index 297db2a82..d64382060 100644
--- a/docs/source/user_guide_pipeline.rst
+++ b/docs/source/user_guide_pipeline.rst
@@ -133,6 +133,36 @@ can be added to the visualization by setting `hide_unused_outputs` to `False`:

     webbrowser.open("pipeline_full.html")

+*************************
+Pipeline State Management
+*************************
+
+Pipelines support checkpointing and resumption through state management features:
+
+.. code:: python
+
+    # Run pipeline until a specific component
+    state = await pipeline.run_until(data, stop_after="component_name", state_file="state.json")
+
+    # Resume pipeline from a specific component
+    result = await pipeline.resume_from(state, data, start_from="component_name")
+
+    # Alternatively, load state from file
+    result = await pipeline.resume_from(None, data, start_from="component_name", state_file="state.json")
+
+The state contains:
+- Pipeline configuration (parameter mappings between components and validation state)
+- Execution results (outputs from completed components stored in the ResultStore)
+- Final pipeline results from previous runs
+- Component-specific state (interface available but not yet implemented by components)
+
+This enables:
+- Checkpointing long-running pipelines
+- Debugging pipeline execution
+- Resuming failed pipelines from the last successful component
+- Comparing different component implementations with deterministic inputs by saving the state before the component and reusing it, avoiding non-deterministic results from preceding components
+
+
 ************************
 Adding an Event Callback
 ************************

From 9a57046ea23403f09c0b148fb01e093918ce3971 Mon Sep 17 00:00:00 2001
From: nathaliecharbel
Date: Wed, 4 Jun 2025 12:24:12 +0200
Subject: [PATCH 07/20] Ruff

---
 src/neo4j_graphrag/experimental/pipeline/pipeline.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/neo4j_graphrag/experimental/pipeline/pipeline.py b/src/neo4j_graphrag/experimental/pipeline/pipeline.py
index d8ca842b6..784d23d09 100644
--- a/src/neo4j_graphrag/experimental/pipeline/pipeline.py
+++ b/src/neo4j_graphrag/experimental/pipeline/pipeline.py
@@ -18,7 +18,7 @@
 import warnings
 from collections import defaultdict
 from timeit import default_timer
-from typing import Any, Optional, AsyncGenerator, Dict, Union
+from typing import Any, Optional, AsyncGenerator, Dict
 import asyncio
 import json

From d1f7389a8384eaf51d20ecb873fc51d8187dac24 Mon Sep 17 00:00:00 2001
From: nathaliecharbel
Date: Tue, 10 Jun 2025 15:45:50 +0200
Subject: [PATCH 08/20] Remove state management for component

---
 .../experimental/pipeline/component.py       | 22 -------------------
 .../unit/experimental/pipeline/components.py | 16 --------------
 2 files changed, 38 deletions(-)

diff --git a/src/neo4j_graphrag/experimental/pipeline/component.py b/src/neo4j_graphrag/experimental/pipeline/component.py
index 69289feb4..d8d988eda 100644
--- a/src/neo4j_graphrag/experimental/pipeline/component.py
+++ b/src/neo4j_graphrag/experimental/pipeline/component.py
@@ -108,25 +108,3 @@ async def
run_with_context( """ # default behavior to prevent a breaking change return await self.run(*args, **kwargs) - - def serialize_state(self) -> Dict[str, Any]: - """Serialize component state to a dictionary. - - This method can be overridden by components to customize - their serialization behavior. By default, it returns an empty dict. - - Returns: - Dict[str, Any]: Serialized state of the component - """ - return {} - - def load_state(self, state: Dict[str, Any]) -> None: - """Load component state from a serialized dictionary. - - This method can be overridden by components to customize - their deserialization behavior. By default, it does nothing. - - Args: - state (Dict[str, Any]): Previously serialized component state - """ - pass diff --git a/tests/unit/experimental/pipeline/components.py b/tests/unit/experimental/pipeline/components.py index 7fd052cd0..e1b2e1845 100644 --- a/tests/unit/experimental/pipeline/components.py +++ b/tests/unit/experimental/pipeline/components.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import asyncio -from typing import Dict, Any from neo4j_graphrag.experimental.pipeline import Component, DataModel from neo4j_graphrag.experimental.pipeline.types.context import RunContext @@ -64,18 +63,3 @@ def __init__(self, sleep: float = 1.0) -> None: async def run(self, number1: int, number2: int = 2) -> IntResultModel: await asyncio.sleep(self.sleep) return IntResultModel(result=number1 * number2) - - -class StatefulComponent(Component): - def __init__(self) -> None: - self.counter = 0 - - async def run(self, value: int) -> IntResultModel: - self.counter += value - return IntResultModel(result=self.counter) - - def serialize_state(self) -> Dict[str, Any]: - return {"counter": self.counter} - - def load_state(self, state: Dict[str, Any]) -> None: - self.counter = state.get("counter", 0) From 8bc8ac00083debf67c9f92833b4614ab2808e567 Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Tue, 10 Jun 2025 17:06:48 +0200 Subject: [PATCH 09/20] Remove resume_from and run_until and reuse existing run interface --- .../experimental/pipeline/pipeline.py | 147 ++++-------------- 1 file changed, 31 insertions(+), 116 deletions(-) diff --git a/src/neo4j_graphrag/experimental/pipeline/pipeline.py b/src/neo4j_graphrag/experimental/pipeline/pipeline.py index 784d23d09..02b99b7b1 100644 --- a/src/neo4j_graphrag/experimental/pipeline/pipeline.py +++ b/src/neo4j_graphrag/experimental/pipeline/pipeline.py @@ -20,7 +20,6 @@ from timeit import default_timer from typing import Any, Optional, AsyncGenerator, Dict import asyncio -import json from neo4j_graphrag.utils.logging import prettify @@ -564,51 +563,58 @@ async def event_stream(event: Event) -> None: if event_queue_getter_task and not event_queue_getter_task.done(): event_queue_getter_task.cancel() - async def run(self, data: dict[str, Any]) -> PipelineResult: + async def run( + self, + data: dict[str, Any], + from_: Optional[str] = None, + until: Optional[str] = None, + ) -> PipelineResult: + """Run the pipeline, optionally from a specific component or until a specific component. + + Args: + data (dict[str, Any]): The input data for the pipeline + from_ (str | None, optional): If provided, start execution from this component. Defaults to None. + until (str | None, optional): If provided, stop execution after this component. Defaults to None. 
+ + Returns: + PipelineResult: The result of the pipeline execution + """ logger.debug("PIPELINE START") start_time = default_timer() self.invalidate() self.validate_input_data(data) - orchestrator = Orchestrator(self) + + # create orchestrator with appropriate start_from and stop_after params + orchestrator = Orchestrator(self, stop_after=until, start_from=from_) + logger.debug(f"PIPELINE ORCHESTRATOR: {orchestrator.run_id}") await orchestrator.run(data) + end_time = default_timer() logger.debug( f"PIPELINE FINISHED {orchestrator.run_id} in {end_time - start_time}s" ) + return PipelineResult( run_id=orchestrator.run_id, result=await self.get_final_results(orchestrator.run_id), ) def dump_state(self, run_id: str) -> Dict[str, Any]: - """Dump the current state of the pipeline and its components to a serializable dictionary. + """Dump the current state of the pipeline to a serializable dictionary. Args: run_id: The run_id that was used when the pipeline was executed Returns: - dict[str, Any]: A serializable dictionary containing the pipeline state + Dict[str, Any]: A serializable dictionary containing the pipeline state """ pipeline_state: Dict[str, Any] = { "run_id": run_id, - "components": {}, - "store": self.store.dump() if hasattr(self.store, "dump") else {}, - "final_results": self.final_results.dump() - if hasattr(self.final_results, "dump") - else {}, + "store": self.store.dump(), + "final_results": self.final_results.dump(), "is_validated": self.is_validated, - "param_mapping": self.param_mapping, - "missing_inputs": self.missing_inputs, } - - components_dict: Dict[str, Any] = {} - pipeline_state["components"] = components_dict - - # serialize each component's state - for name, node in self._nodes.items(): - components_dict[name] = node.component.serialize_state() - return pipeline_state def load_state(self, state: Dict[str, Any]) -> None: @@ -617,105 +623,14 @@ def load_state(self, state: Dict[str, Any]) -> None: Args: state (dict[str, Any]): Previously serialized pipeline state """ - # load component states - for name, component_state in state.get("components", {}).items(): - if name in self._nodes: - self._nodes[name].component.load_state(component_state) - - # load other pipeline state attributes + # load pipeline state attributes if "is_validated" in state: self.is_validated = state["is_validated"] - if "param_mapping" in state: - self.param_mapping = state["param_mapping"] - - if "missing_inputs" in state: - self.missing_inputs = state["missing_inputs"] - - # load store data if store has load method - if "store" in state and hasattr(self.store, "load"): + # load store data + if "store" in state: self.store.load(state["store"]) - # load final results if it has load method - if "final_results" in state and hasattr(self.final_results, "load"): + # load final results + if "final_results" in state: self.final_results.load(state["final_results"]) - - async def run_until( - self, data: Dict[str, Any], stop_after: str, state_file: Optional[str] = None - ) -> Dict[str, Any]: - """ - Run the pipeline until a specific component and return the state. - - Args: - data (Dict[str, Any]): The input data for the pipeline. - stop_after (str): The name of the component to stop after. - state_file (Optional[str]): If provided, save the state to this file as JSON. - - Returns: - Dict[str, Any]: The serialized state of the pipeline after execution. 
- """ - logger.debug("PIPELINE START (RUN UNTIL)") - start_time = default_timer() - self.invalidate() - self.validate_input_data(data) - orchestrator = Orchestrator(self, stop_after=stop_after) - logger.debug(f"PIPELINE ORCHESTRATOR: {orchestrator.run_id}") - await orchestrator.run(data) - end_time = default_timer() - logger.debug( - f"PIPELINE FINISHED (RUN UNTIL) {orchestrator.run_id} in {end_time - start_time}s" - ) - state = self.dump_state(orchestrator.run_id) - if state_file: - with open(state_file, "w", encoding="utf-8") as f: - json.dump(state, f, ensure_ascii=False, indent=2) - return state - - async def resume_from( - self, - state: Optional[ - Dict[str, Any] - ], # Required but can be None if state_file is provided - data: Dict[str, Any], - start_from: str, - state_file: Optional[str] = None, - ) -> PipelineResult: - """ - Resume pipeline execution from a specific component using a saved state. - - Args: - state (Optional[Dict[str, Any]]): The serialized pipeline state. Required, but can be None if state_file is provided. - data (Dict[str, Any]): Additional input data for the pipeline. - start_from (str): The name of the component to start execution from. - state_file (Optional[str]): If provided, load the state from this file as JSON. Required if state is None. - - Returns: - PipelineResult: The result of the pipeline execution. - - Raises: - ValueError: If neither state nor state_file is provided. - """ - if state_file: - with open(state_file, "r", encoding="utf-8") as f: - state = json.load(f) - if state is None: - raise ValueError("No state provided for resume_from.") - self.load_state(state) - run_id = state.get("run_id") - if not run_id: - raise ValueError("No run_id found in state. Cannot resume execution.") - logger.debug("PIPELINE START (RESUME FROM)") - start_time = default_timer() - self.validate_input_data(data) - orchestrator = Orchestrator(self, start_from=start_from) - orchestrator.run_id = run_id # Use the original run_id - logger.debug(f"PIPELINE ORCHESTRATOR: {orchestrator.run_id}") - await orchestrator.run(data) - end_time = default_timer() - logger.debug( - f"PIPELINE FINISHED (RESUME FROM) {orchestrator.run_id} in {end_time - start_time}s" - ) - return PipelineResult( - run_id=orchestrator.run_id, - result=await self.get_final_results(orchestrator.run_id), - ) From f76e5ca594cbf533f07cfbdf4f41d428cef5b59f Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Tue, 10 Jun 2025 17:07:16 +0200 Subject: [PATCH 10/20] Add dump and load to InMemoryStore --- .../experimental/pipeline/stores.py | 34 +++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/src/neo4j_graphrag/experimental/pipeline/stores.py b/src/neo4j_graphrag/experimental/pipeline/stores.py index 855546d6d..913fb4bdb 100644 --- a/src/neo4j_graphrag/experimental/pipeline/stores.py +++ b/src/neo4j_graphrag/experimental/pipeline/stores.py @@ -90,6 +90,24 @@ async def add_result_for_component( async def get_result_for_component(self, run_id: str, task_name: str) -> Any: return await self.get(self.get_key(run_id, task_name)) + @abc.abstractmethod + def dump(self) -> dict[str, Any]: + """Dump the store state to a serializable dictionary. + + Returns: + dict[str, Any]: A serializable dictionary containing the store state + """ + pass + + @abc.abstractmethod + def load(self, state: dict[str, Any]) -> None: + """Load the store state from a serializable dictionary. 
+ + Args: + state (dict[str, Any]): A serializable dictionary containing the store state + """ + pass + class InMemoryStore(ResultStore): """Simple in-memory store. @@ -115,3 +133,19 @@ def all(self) -> dict[str, Any]: def empty(self) -> None: self._data = {} + + def dump(self) -> dict[str, Any]: + """Dump the store state to a serializable dictionary. + + Returns: + dict[str, Any]: A serializable dictionary containing the store state + """ + return self._data.copy() + + def load(self, state: dict[str, Any]) -> None: + """Load the store state from a serializable dictionary. + + Args: + state (dict[str, Any]): A serializable dictionary containing the store state + """ + self._data = state.copy() From 2cc5b99250f321d90b7c7b9aaafe64c167808b23 Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Tue, 10 Jun 2025 17:47:16 +0200 Subject: [PATCH 11/20] Ruff --- src/neo4j_graphrag/experimental/pipeline/component.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/neo4j_graphrag/experimental/pipeline/component.py b/src/neo4j_graphrag/experimental/pipeline/component.py index d8d988eda..39a2816ef 100644 --- a/src/neo4j_graphrag/experimental/pipeline/component.py +++ b/src/neo4j_graphrag/experimental/pipeline/component.py @@ -15,7 +15,7 @@ from __future__ import annotations import inspect -from typing import Any, Dict, get_type_hints +from typing import Any, get_type_hints from pydantic import BaseModel From 2a819ea9614837658cae299a770ff0c0b6ad82dd Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Wed, 11 Jun 2025 16:44:21 +0200 Subject: [PATCH 12/20] Add ability ro load and dump state by run_id --- .../experimental/pipeline/stores.py | 50 ++++++++++++++----- 1 file changed, 37 insertions(+), 13 deletions(-) diff --git a/src/neo4j_graphrag/experimental/pipeline/stores.py b/src/neo4j_graphrag/experimental/pipeline/stores.py index 913fb4bdb..325f178c4 100644 --- a/src/neo4j_graphrag/experimental/pipeline/stores.py +++ b/src/neo4j_graphrag/experimental/pipeline/stores.py @@ -91,19 +91,23 @@ async def get_result_for_component(self, run_id: str, task_name: str) -> Any: return await self.get(self.get_key(run_id, task_name)) @abc.abstractmethod - def dump(self) -> dict[str, Any]: - """Dump the store state to a serializable dictionary. + def dump(self, run_id: str) -> dict[str, Any]: + """Dump the store state for a specific run_id to a serializable dictionary. + + Args: + run_id (str): The run_id to dump data for Returns: - dict[str, Any]: A serializable dictionary containing the store state + dict[str, Any]: A serializable dictionary containing the store state for the run_id """ pass @abc.abstractmethod - def load(self, state: dict[str, Any]) -> None: - """Load the store state from a serializable dictionary. + def load(self, run_id: str, state: dict[str, Any]) -> None: + """Load the store state for a specific run_id from a serializable dictionary. Args: + run_id (str): The run_id to load data for state (dict[str, Any]): A serializable dictionary containing the store state """ pass @@ -134,18 +138,38 @@ def all(self) -> dict[str, Any]: def empty(self) -> None: self._data = {} - def dump(self) -> dict[str, Any]: - """Dump the store state to a serializable dictionary. + def dump(self, run_id: str) -> dict[str, Any]: + """Dump the store state for a specific run_id to a serializable dictionary. 
+ + Args: + run_id (str): The run_id to dump data for Returns: - dict[str, Any]: A serializable dictionary containing the store state + dict[str, Any]: A serializable dictionary containing the store state for the run_id """ - return self._data.copy() - - def load(self, state: dict[str, Any]) -> None: - """Load the store state from a serializable dictionary. + # filter data by run_id prefix + run_id_prefix = f"{run_id}:" + filtered_data = { + key: value + for key, value in self._data.items() + if key.startswith(run_id_prefix) + } + return filtered_data + + def load(self, run_id: str, state: dict[str, Any]) -> None: + """Load the store state for a specific run_id from a serializable dictionary. Args: + run_id (str): The run_id to load data for state (dict[str, Any]): A serializable dictionary containing the store state """ - self._data = state.copy() + # clear existing data for this run_id first + run_id_prefix = f"{run_id}:" + keys_to_remove = [ + key for key in self._data.keys() if key.startswith(run_id_prefix) + ] + for key in keys_to_remove: + del self._data[key] + + # load the new state data + self._data.update(state) From 472eaf105713a6c1d7d9240cffc30b1f64529c45 Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Wed, 11 Jun 2025 17:11:41 +0200 Subject: [PATCH 13/20] Allow orchestrator to run use a run_id from previous run --- src/neo4j_graphrag/experimental/pipeline/orchestrator.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/neo4j_graphrag/experimental/pipeline/orchestrator.py b/src/neo4j_graphrag/experimental/pipeline/orchestrator.py index 550b99ecf..0de454817 100644 --- a/src/neo4j_graphrag/experimental/pipeline/orchestrator.py +++ b/src/neo4j_graphrag/experimental/pipeline/orchestrator.py @@ -63,10 +63,11 @@ def __init__( pipeline: Pipeline, stop_after: Optional[str] = None, start_from: Optional[str] = None, + run_id: Optional[str] = None, ): self.pipeline = pipeline self.event_notifier = EventNotifier(pipeline.callbacks) - self.run_id = str(uuid.uuid4()) + self.run_id = run_id or str(uuid.uuid4()) self.stop_after = stop_after self.start_from = start_from From 13254ddf137ce2839e867f107912ead748b4a561 Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Wed, 11 Jun 2025 17:15:33 +0200 Subject: [PATCH 14/20] Refactor pipeline and validate loaded state --- .../experimental/pipeline/pipeline.py | 134 ++++++++++++++++-- 1 file changed, 119 insertions(+), 15 deletions(-) diff --git a/src/neo4j_graphrag/experimental/pipeline/pipeline.py b/src/neo4j_graphrag/experimental/pipeline/pipeline.py index 02b99b7b1..8b01ad29d 100644 --- a/src/neo4j_graphrag/experimental/pipeline/pipeline.py +++ b/src/neo4j_graphrag/experimental/pipeline/pipeline.py @@ -140,6 +140,7 @@ def __init__( } """ self.missing_inputs: dict[str, list[str]] = defaultdict() + self._current_run_id: Optional[str] = None @classmethod def from_template( @@ -390,7 +391,9 @@ def validate_parameter_mapping(self) -> None: self.validate_parameter_mapping_for_task(task) self.is_validated = True - def validate_input_data(self, data: dict[str, Any]) -> bool: + def validate_input_data( + self, data: dict[str, Any], from_: Optional[str] = None + ) -> bool: """Performs parameter and data validation before running the pipeline: - Check parameters defined in the connect method - Make sure the missing parameters are present in the input `data` dict. 
@@ -398,6 +401,8 @@ def validate_input_data(self, data: dict[str, Any]) -> bool: Args: data (dict[str, Any]): input data to use for validation (usually from Pipeline.run) + from_ (Optional[str]): If provided, only validate components that will actually execute + starting from this component Raises: PipelineDefinitionError if any parameter mapping is invalid or if a @@ -405,7 +410,15 @@ def validate_input_data(self, data: dict[str, Any]) -> bool: """ if not self.is_validated: self.validate_parameter_mapping() + + # determine which components need validation + components_to_validate = self._get_components_to_validate(from_) + for task in self._nodes.values(): + # skip validation for components that won't execute + if task.name not in components_to_validate: + continue + if task.name not in self.param_mapping: self.validate_parameter_mapping_for_task(task) missing_params = self.missing_inputs[task.name] @@ -417,6 +430,37 @@ def validate_input_data(self, data: dict[str, Any]) -> bool: ) return True + def _get_components_to_validate(self, from_: Optional[str] = None) -> set[str]: + """Determine which components need validation based on execution context. + + Args: + from_ (Optional[str]): Starting component for execution + + Returns: + set[str]: Set of component names that need validation + """ + if from_ is None: + # no from_ specified, validate all components + return set(self._nodes.keys()) + + # when from_ is specified, only validate components that will actually execute + # this includes the from_ component and all its downstream dependencies + components_to_validate = set() + + def add_downstream_components(component_name: str) -> None: + """Recursively add a component and all its downstream dependencies""" + if component_name in components_to_validate: + return # Already processed + components_to_validate.add(component_name) + + # add all components that depend on this one + for edge in self.next_edges(component_name): + add_downstream_components(edge.end) + + # start from the specified component and add all downstream + add_downstream_components(from_) + return components_to_validate + def validate_parameter_mapping_for_task(self, task: TaskPipelineNode) -> bool: """Make sure that all the parameter mapping for a given task are valid. Does not consider user input yet. @@ -582,10 +626,16 @@ async def run( logger.debug("PIPELINE START") start_time = default_timer() self.invalidate() - self.validate_input_data(data) + self.validate_input_data(data, from_) # create orchestrator with appropriate start_from and stop_after params - orchestrator = Orchestrator(self, stop_after=until, start_from=from_) + # if current run_id exists (from loaded state), use it to continue the same run + orchestrator = Orchestrator( + self, stop_after=until, start_from=from_, run_id=self._current_run_id + ) + + # Track the current run_id + self._current_run_id = orchestrator.run_id logger.debug(f"PIPELINE ORCHESTRATOR: {orchestrator.run_id}") await orchestrator.run(data) @@ -600,20 +650,23 @@ async def run( result=await self.get_final_results(orchestrator.run_id), ) - def dump_state(self, run_id: str) -> Dict[str, Any]: + def dump_state(self) -> Dict[str, Any]: """Dump the current state of the pipeline to a serializable dictionary. 
- Args: - run_id: The run_id that was used when the pipeline was executed - Returns: Dict[str, Any]: A serializable dictionary containing the pipeline state + + Raises: + ValueError: If no pipeline run has been executed yet """ + if self._current_run_id is None: + raise ValueError( + "No pipeline run has been executed yet. Cannot dump state without a run_id." + ) + pipeline_state: Dict[str, Any] = { - "run_id": run_id, - "store": self.store.dump(), - "final_results": self.final_results.dump(), - "is_validated": self.is_validated, + "run_id": self._current_run_id, + "store": self.store.dump(self._current_run_id), } return pipeline_state @@ -622,15 +675,66 @@ def load_state(self, state: Dict[str, Any]) -> None: Args: state (dict[str, Any]): Previously serialized pipeline state + + Raises: + ValueError: If the state is invalid or incompatible with current pipeline """ + if "run_id" not in state: + raise ValueError("Invalid state: missing run_id") + + run_id = state["run_id"] + + # validate pipeline compatibility + self._validate_state_compatibility(state) + + # set the current run_id + self._current_run_id = run_id + # load pipeline state attributes if "is_validated" in state: self.is_validated = state["is_validated"] # load store data if "store" in state: - self.store.load(state["store"]) + self.store.load(run_id, state["store"]) + + def _validate_state_compatibility(self, state: Dict[str, Any]) -> None: + """Validate that the loaded state is compatible with the current pipeline. + + This checks that the components defined in the pipeline match those + that were present when the state was saved. + + Args: + state (dict[str, Any]): The state to validate - # load final results - if "final_results" in state: - self.final_results.load(state["final_results"]) + Raises: + ValueError: If the state is incompatible with the current pipeline + """ + if "store" not in state: + return # no store data to validate + + store_data = state["store"] + if not store_data: + return # empty store, nothing to validate + + # extract component names from the store keys + # keys are in format: "run_id:component_name" or "run_id:component_name:suffix" + stored_components = set() + for key in store_data.keys(): + parts = key.split(":") + if len(parts) >= 2: + component_name = parts[1] + stored_components.add(component_name) + + # get current pipeline component names + current_components = set(self._nodes.keys()) + + # check if stored components are a subset of current components + # this allows for the pipeline to have additional components, but not missing ones + missing_components = stored_components - current_components + if missing_components: + raise ValueError( + f"State is incompatible with current pipeline. " + f"Missing components: {sorted(missing_components)}. 
" + f"Current pipeline components: {sorted(current_components)}" + ) From a9413d37ef958d358d939eb513dddab99bb5c08b Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Fri, 13 Jun 2025 12:45:40 +0200 Subject: [PATCH 15/20] Refactor pipeline run_id management --- .../experimental/pipeline/pipeline.py | 35 ++++++------------- 1 file changed, 11 insertions(+), 24 deletions(-) diff --git a/src/neo4j_graphrag/experimental/pipeline/pipeline.py b/src/neo4j_graphrag/experimental/pipeline/pipeline.py index 8b01ad29d..04f7903a5 100644 --- a/src/neo4j_graphrag/experimental/pipeline/pipeline.py +++ b/src/neo4j_graphrag/experimental/pipeline/pipeline.py @@ -140,7 +140,6 @@ def __init__( } """ self.missing_inputs: dict[str, list[str]] = defaultdict() - self._current_run_id: Optional[str] = None @classmethod def from_template( @@ -440,7 +439,7 @@ def _get_components_to_validate(self, from_: Optional[str] = None) -> set[str]: set[str]: Set of component names that need validation """ if from_ is None: - # no from_ specified, validate all components + # no from_ specified, validate all components return set(self._nodes.keys()) # when from_ is specified, only validate components that will actually execute @@ -629,13 +628,7 @@ async def run( self.validate_input_data(data, from_) # create orchestrator with appropriate start_from and stop_after params - # if current run_id exists (from loaded state), use it to continue the same run - orchestrator = Orchestrator( - self, stop_after=until, start_from=from_, run_id=self._current_run_id - ) - - # Track the current run_id - self._current_run_id = orchestrator.run_id + orchestrator = Orchestrator(self, stop_after=until, start_from=from_) logger.debug(f"PIPELINE ORCHESTRATOR: {orchestrator.run_id}") await orchestrator.run(data) @@ -650,23 +643,24 @@ async def run( result=await self.get_final_results(orchestrator.run_id), ) - def dump_state(self) -> Dict[str, Any]: + def dump_state(self, run_id: str) -> Dict[str, Any]: """Dump the current state of the pipeline to a serializable dictionary. + Args: + run_id (str): The run_id to dump state for + Returns: Dict[str, Any]: A serializable dictionary containing the pipeline state Raises: - ValueError: If no pipeline run has been executed yet + ValueError: If run_id is None or empty """ - if self._current_run_id is None: - raise ValueError( - "No pipeline run has been executed yet. Cannot dump state without a run_id." 
- ) + if not run_id: + raise ValueError("run_id cannot be None or empty") pipeline_state: Dict[str, Any] = { - "run_id": self._current_run_id, - "store": self.store.dump(self._current_run_id), + "run_id": run_id, + "store": self.store.dump(run_id), } return pipeline_state @@ -687,13 +681,6 @@ def load_state(self, state: Dict[str, Any]) -> None: # validate pipeline compatibility self._validate_state_compatibility(state) - # set the current run_id - self._current_run_id = run_id - - # load pipeline state attributes - if "is_validated" in state: - self.is_validated = state["is_validated"] - # load store data if "store" in state: self.store.load(run_id, state["store"]) From 64bdc66280107555da7b86f73610a3555a6fc1f0 Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Fri, 13 Jun 2025 12:46:10 +0200 Subject: [PATCH 16/20] Ensure previous run_ids are kept in store --- src/neo4j_graphrag/experimental/pipeline/stores.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/src/neo4j_graphrag/experimental/pipeline/stores.py b/src/neo4j_graphrag/experimental/pipeline/stores.py index 325f178c4..da783aed3 100644 --- a/src/neo4j_graphrag/experimental/pipeline/stores.py +++ b/src/neo4j_graphrag/experimental/pipeline/stores.py @@ -163,13 +163,6 @@ def load(self, run_id: str, state: dict[str, Any]) -> None: run_id (str): The run_id to load data for state (dict[str, Any]): A serializable dictionary containing the store state """ - # clear existing data for this run_id first - run_id_prefix = f"{run_id}:" - keys_to_remove = [ - key for key in self._data.keys() if key.startswith(run_id_prefix) - ] - for key in keys_to_remove: - del self._data[key] - - # load the new state data + # add/update data without clearing - safer for concurrent access + # multiple pipelines can load the same state without interfering with each other self._data.update(state) From d020a7eaac35c8a0342ccae499885a1183ffad8e Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Fri, 13 Jun 2025 12:46:41 +0200 Subject: [PATCH 17/20] Ensure resume run with different run_ids --- src/neo4j_graphrag/experimental/pipeline/orchestrator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/neo4j_graphrag/experimental/pipeline/orchestrator.py b/src/neo4j_graphrag/experimental/pipeline/orchestrator.py index 0de454817..65b1ec8ce 100644 --- a/src/neo4j_graphrag/experimental/pipeline/orchestrator.py +++ b/src/neo4j_graphrag/experimental/pipeline/orchestrator.py @@ -63,11 +63,10 @@ def __init__( pipeline: Pipeline, stop_after: Optional[str] = None, start_from: Optional[str] = None, - run_id: Optional[str] = None, ): self.pipeline = pipeline self.event_notifier = EventNotifier(pipeline.callbacks) - self.run_id = run_id or str(uuid.uuid4()) + self.run_id = str(uuid.uuid4()) self.stop_after = stop_after self.start_from = start_from @@ -288,6 +287,7 @@ async def run(self, data: dict[str, Any]) -> None: start_nodes = [self.pipeline.get_node_by_name(self.start_from)] else: start_nodes = self.pipeline.roots() + tasks = [self.run_task(root, data) for root in start_nodes] await asyncio.gather(*tasks) await self.event_notifier.notify_pipeline_finished( From c97bb975fc8021e74fc155c495ca8e84b77b68de Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Mon, 16 Jun 2025 15:09:20 +0200 Subject: [PATCH 18/20] Cleanup stores --- src/neo4j_graphrag/experimental/pipeline/stores.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/neo4j_graphrag/experimental/pipeline/stores.py 
b/src/neo4j_graphrag/experimental/pipeline/stores.py index da783aed3..8d5d9a1dc 100644 --- a/src/neo4j_graphrag/experimental/pipeline/stores.py +++ b/src/neo4j_graphrag/experimental/pipeline/stores.py @@ -103,11 +103,10 @@ def dump(self, run_id: str) -> dict[str, Any]: pass @abc.abstractmethod - def load(self, run_id: str, state: dict[str, Any]) -> None: - """Load the store state for a specific run_id from a serializable dictionary. + def load(self, state: dict[str, Any]) -> None: + """Load the store state from a serializable dictionary. Args: - run_id (str): The run_id to load data for state (dict[str, Any]): A serializable dictionary containing the store state """ pass @@ -156,11 +155,10 @@ def dump(self, run_id: str) -> dict[str, Any]: } return filtered_data - def load(self, run_id: str, state: dict[str, Any]) -> None: - """Load the store state for a specific run_id from a serializable dictionary. + def load(self, state: dict[str, Any]) -> None: + """Load the store state from a serializable dictionary. Args: - run_id (str): The run_id to load data for state (dict[str, Any]): A serializable dictionary containing the store state """ # add/update data without clearing - safer for concurrent access From 74e7db01e800aadc1e9eb52d1de19a202e497980 Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Mon, 16 Jun 2025 15:10:07 +0200 Subject: [PATCH 19/20] Ensure proper handling of previous run_ids --- .../experimental/pipeline/orchestrator.py | 18 +++++++++++++++++- .../experimental/pipeline/pipeline.py | 17 +++++++++++++---- 2 files changed, 30 insertions(+), 5 deletions(-) diff --git a/src/neo4j_graphrag/experimental/pipeline/orchestrator.py b/src/neo4j_graphrag/experimental/pipeline/orchestrator.py index 65b1ec8ce..9dc5c97e6 100644 --- a/src/neo4j_graphrag/experimental/pipeline/orchestrator.py +++ b/src/neo4j_graphrag/experimental/pipeline/orchestrator.py @@ -63,10 +63,12 @@ def __init__( pipeline: Pipeline, stop_after: Optional[str] = None, start_from: Optional[str] = None, + previous_run_id: Optional[str] = None, ): self.pipeline = pipeline self.event_notifier = EventNotifier(pipeline.callbacks) self.run_id = str(uuid.uuid4()) + self.previous_run_id = previous_run_id # useful for pipeline resumption self.stop_after = stop_after self.start_from = start_from @@ -268,10 +270,24 @@ async def add_result_for_component( ) async def get_results_for_component(self, name: str) -> Any: + # when resuming, check previous run_id, otherwise check current run_id + if self.previous_run_id: + return await self.pipeline.store.get_result_for_component( + self.previous_run_id, name + ) return await self.pipeline.store.get_result_for_component(self.run_id, name) async def get_status_for_component(self, name: str) -> RunStatus: - status = await self.pipeline.store.get_status_for_component(self.run_id, name) + # when resuming, check previous run_id, otherwise check current run_id + if self.previous_run_id: + status = await self.pipeline.store.get_status_for_component( + self.previous_run_id, name + ) + else: + status = await self.pipeline.store.get_status_for_component( + self.run_id, name + ) + if status is None: return RunStatus.UNKNOWN return RunStatus(status) diff --git a/src/neo4j_graphrag/experimental/pipeline/pipeline.py b/src/neo4j_graphrag/experimental/pipeline/pipeline.py index 04f7903a5..5d9d43809 100644 --- a/src/neo4j_graphrag/experimental/pipeline/pipeline.py +++ b/src/neo4j_graphrag/experimental/pipeline/pipeline.py @@ -611,6 +611,7 @@ async def run( data: dict[str, Any], from_: Optional[str] = None, 
until: Optional[str] = None, + previous_run_id: Optional[str] = None, ) -> PipelineResult: """Run the pipeline, optionally from a specific component or until a specific component. @@ -618,6 +619,7 @@ async def run( data (dict[str, Any]): The input data for the pipeline from_ (str | None, optional): If provided, start execution from this component. Defaults to None. until (str | None, optional): If provided, stop execution after this component. Defaults to None. + previous_run_id (str | None, optional): If provided, resume from this previous run_id. Defaults to None. Returns: PipelineResult: The result of the pipeline execution @@ -628,7 +630,9 @@ async def run( self.validate_input_data(data, from_) # create orchestrator with appropriate start_from and stop_after params - orchestrator = Orchestrator(self, stop_after=until, start_from=from_) + orchestrator = Orchestrator( + self, stop_after=until, start_from=from_, previous_run_id=previous_run_id + ) logger.debug(f"PIPELINE ORCHESTRATOR: {orchestrator.run_id}") await orchestrator.run(data) @@ -664,26 +668,31 @@ def dump_state(self, run_id: str) -> Dict[str, Any]: } return pipeline_state - def load_state(self, state: Dict[str, Any]) -> None: + def load_state(self, state: Dict[str, Any]) -> str: """Load pipeline state from a serialized dictionary. Args: state (dict[str, Any]): Previously serialized pipeline state + Returns: + str: The run_id from the loaded state + Raises: ValueError: If the state is invalid or incompatible with current pipeline """ if "run_id" not in state: raise ValueError("Invalid state: missing run_id") - run_id = state["run_id"] + run_id: str = state["run_id"] # validate pipeline compatibility self._validate_state_compatibility(state) # load store data if "store" in state: - self.store.load(run_id, state["store"]) + self.store.load(state["store"]) + + return run_id def _validate_state_compatibility(self, state: Dict[str, Any]) -> None: """Validate that the loaded state is compatible with the current pipeline. From 0ca4c60121899240bb73f087a605d99f36d47a01 Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Mon, 16 Jun 2025 16:05:22 +0200 Subject: [PATCH 20/20] Update changelog and docs --- CHANGELOG.md | 2 +- docs/source/user_guide_pipeline.rst | 60 ++++++++++++++--------------- 2 files changed, 31 insertions(+), 31 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ab5362af9..695ab2a25 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,7 +6,7 @@ - Added support for automatic schema extraction from text using LLMs. In the `SimpleKGPipeline`, when the user provides no schema, the automatic schema extraction is enabled by default. - Added ability to return a user-defined message if context is empty in GraphRAG (which skips the LLM call). -- Added pipeline state management with `run_until`, `resume_from`, `dump_state`, and `load_state` methods, enabling pipeline execution checkpointing and resumption. +- Added pipeline execution control with state management (`dump_state()` and `load_state()` methods) and partial execution support in the `run()` method (with `until` and `from_` parameters), enabling pipeline state dump and resumption of long-running pipelines, debugging workflows, and incremental processing. 
 ### Fixed
 
diff --git a/docs/source/user_guide_pipeline.rst b/docs/source/user_guide_pipeline.rst
index d64382060..18ab8a745 100644
--- a/docs/source/user_guide_pipeline.rst
+++ b/docs/source/user_guide_pipeline.rst
@@ -133,36 +133,6 @@ can be added to the visualization by setting `hide_unused_outputs` to `False`:
     webbrowser.open("pipeline_full.html")
 
 
-*************************
-Pipeline State Management
-*************************
-
-Pipelines support checkpointing and resumption through state management features:
-
-.. code:: python
-
-    # Run pipeline until a specific component
-    state = await pipeline.run_until(data, stop_after="component_name", state_file="state.json")
-
-    # Resume pipeline from a specific component
-    result = await pipeline.resume_from(state, data, start_from="component_name")
-
-    # Alternatively, load state from file
-    result = await pipeline.resume_from(None, data, start_from="component_name", state_file="state.json")
-
-The state contains:
-- Pipeline configuration (parameter mappings between components and validation state)
-- Execution results (outputs from completed components stored in the ResultStore)
-- Final pipeline results from previous runs
-- Component-specific state (interface available but not yet implemented by components)
-
-This enables:
-- Checkpointing long-running pipelines
-- Debugging pipeline execution
-- Resuming failed pipelines from the last successful component
-- Comparing different component implementations with deterministic inputs by saving the state before the component and reusing it, avoiding non-deterministic results from preceding components
-
-
 ************************
 Adding an Event Callback
 ************************
@@ -234,3 +204,33 @@ This will send an `TASK_PROGRESS` event to the pipeline callback.
 .. note::
 
     In a future release, the `context_` parameter will be added to the `run` method.
+
+
+*************************
+Pipeline State Management
+*************************
+
+Pipelines support state management to enable saving and restoring execution state, which is useful for debugging, resuming long-running pipelines, or incremental processing workflows.
+
+Saving and Loading State
+========================
+
+You can save the current state of a pipeline execution using the `dump_state()` method and restore it with `load_state()`. The pipeline also supports partial execution using the `until` and `from_` parameters:
+
+- `until`: Stop execution after a specific component completes
+- `from_`: Start execution from a specific component instead of from the beginning
+
+.. code:: python
+
+    # Run pipeline and save state
+    result = await pipeline.run(..., until="a")
+    state = pipeline.dump_state(result.run_id)
+    # The state is a plain dictionary and can be persisted, e.g. as a JSON file
+
+    # Resume the pipeline, potentially in a later session or process
+    loaded_run_id = pipeline.load_state(state)
+    new_result = await pipeline.run(..., from_="b", previous_run_id=loaded_run_id)
+
+.. warning:: State Compatibility
+
+    When loading state, the current pipeline must have at least all the components that were present when the state was saved. Additional components are allowed, but missing components will cause a validation error.
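
Putting the pieces together, a complete workflow using the APIs introduced in this patch series (`until`, `from_`, `previous_run_id`, `dump_state()`, `load_state()`) might look like the sketch below. This is an illustration only, not part of the diff: the `ComponentAdd`/`ComponentMultiply` components, the component names `a` and `b`, the import paths, and the `add_component`/`connect` wiring are assumptions modelled on the existing pipeline user guide, and persisting the dumped state as JSON assumes the stored component results are JSON-serializable.

.. code:: python

    import asyncio
    import json

    from neo4j_graphrag.experimental.pipeline import Component, DataModel, Pipeline


    class IntResultModel(DataModel):
        result: int


    class ComponentAdd(Component):
        async def run(self, number1: int, number2: int) -> IntResultModel:
            return IntResultModel(result=number1 + number2)


    class ComponentMultiply(Component):
        async def run(self, number1: int, number2: int = 2) -> IntResultModel:
            return IntResultModel(result=number1 * number2)


    async def main() -> None:
        pipe = Pipeline()
        pipe.add_component(ComponentAdd(), "a")
        pipe.add_component(ComponentMultiply(), "b")
        # "b" consumes the result produced by "a"
        pipe.connect("a", "b", input_config={"number1": "a.result"})

        # run only up to component "a" and persist the state
        partial = await pipe.run({"a": {"number1": 1, "number2": 2}}, until="a")
        with open("pipeline_state.json", "w") as f:
            # assumes the stored results are JSON-serializable
            json.dump(pipe.dump_state(partial.run_id), f)

        # later, possibly in another process: restore the state and resume from "b"
        with open("pipeline_state.json") as f:
            previous_run_id = pipe.load_state(json.load(f))
        final = await pipe.run({}, from_="b", previous_run_id=previous_run_id)
        print(final.result)


    asyncio.run(main())

Note that `load_state()` returns the run_id stored in the state, and that value is passed back to `run()` as `previous_run_id` so the orchestrator can look up the results and statuses produced by the earlier, partial run.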