diff --git a/CHANGELOG.md b/CHANGELOG.md index a8bad9486..695ab2a25 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ - Added support for automatic schema extraction from text using LLMs. In the `SimpleKGPipeline`, when the user provides no schema, the automatic schema extraction is enabled by default. - Added ability to return a user-defined message if context is empty in GraphRAG (which skips the LLM call). +- Added pipeline execution control with state management (`dump_state()` and `load_state()` methods) and partial execution support in the `run()` method (with `until` and `from_` parameters), enabling pipeline state dump and resumption of long-running pipelines, debugging workflows, and incremental processing. ### Fixed diff --git a/docs/source/user_guide_pipeline.rst b/docs/source/user_guide_pipeline.rst index 297db2a82..18ab8a745 100644 --- a/docs/source/user_guide_pipeline.rst +++ b/docs/source/user_guide_pipeline.rst @@ -204,3 +204,33 @@ This will send an `TASK_PROGRESS` event to the pipeline callback. .. note:: In a future release, the `context_` parameter will be added to the `run` method. + + +************************* +Pipeline State Management +************************* + +Pipelines support state management to enable saving and restoring execution state, which is useful for debugging, resuming long-running pipelines, or incremental processing workflows. + +Saving and Loading State +======================== + +You can save the current state of a pipeline execution using the `dump_state()` method and restore it with `load_state()`. The pipeline also supports partial execution using the `until` and `from_` parameters: + +- **`until`**: Stop execution after a specific component completes +- **`from_`**: Start execution from a specific component instead of from the beginning + +.. code:: python + + # Run pipeline and save state + result = await pipeline.run(..., until="a") + state = pipeline.dump_state(result.run_id) + # The user could save the state to a JSON file + + # Resuming pipeline, could be from another run + loaded_run_id = pipeline.load_state(state) + new_result = await pipeline.run(..., from_="b", previous_run_id=loaded_run_id) + +.. warning:: State Compatibility + + When loading state, the current pipeline must have at least all the components that were present when the state was saved. Additional components are allowed, but missing components will cause a validation error. diff --git a/src/neo4j_graphrag/experimental/pipeline/orchestrator.py b/src/neo4j_graphrag/experimental/pipeline/orchestrator.py index de9468bf7..9dc5c97e6 100644 --- a/src/neo4j_graphrag/experimental/pipeline/orchestrator.py +++ b/src/neo4j_graphrag/experimental/pipeline/orchestrator.py @@ -19,7 +19,7 @@ import uuid import warnings from functools import partial -from typing import TYPE_CHECKING, Any, AsyncGenerator +from typing import TYPE_CHECKING, Any, AsyncGenerator, Optional from neo4j_graphrag.experimental.pipeline.types.context import RunContext from neo4j_graphrag.experimental.pipeline.exceptions import ( @@ -46,16 +46,31 @@ class Orchestrator: - finding the next tasks to execute - building the inputs for each task - calling the run method on each task + - optionally stopping after a specified component + - optionally starting from a specified component Once a TaskNode is done, it calls the `on_task_complete` callback that will save the results, find the next tasks to be executed (checking that all dependencies are met), and run them. 
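+
+    When resuming a previous run, results and statuses of components that have
+    already executed are looked up under `previous_run_id` in the result store.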
+ + Partial execution is supported through: + - stop_after: Stop execution after this component completes + - start_from: Start execution from this component instead of roots """ - def __init__(self, pipeline: Pipeline): + def __init__( + self, + pipeline: Pipeline, + stop_after: Optional[str] = None, + start_from: Optional[str] = None, + previous_run_id: Optional[str] = None, + ): self.pipeline = pipeline self.event_notifier = EventNotifier(pipeline.callbacks) self.run_id = str(uuid.uuid4()) + self.previous_run_id = previous_run_id # useful for pipeline resumption + self.stop_after = stop_after + self.start_from = start_from async def run_task(self, task: TaskPipelineNode, data: dict[str, Any]) -> None: """Get inputs and run a specific task. Once the task is done, @@ -129,7 +144,10 @@ async def on_task_complete( await self.add_result_for_component( task.name, res_to_save, is_final=task.is_leaf() ) - # then get the next tasks to be executed + # stop if this is the stop_after node + if self.stop_after and task.name == self.stop_after: + return + # otherwise, get the next tasks to be executed # and run them in // await asyncio.gather(*[self.run_task(n, data) async for n in self.next(task)]) @@ -252,10 +270,24 @@ async def add_result_for_component( ) async def get_results_for_component(self, name: str) -> Any: + # when resuming, check previous run_id, otherwise check current run_id + if self.previous_run_id: + return await self.pipeline.store.get_result_for_component( + self.previous_run_id, name + ) return await self.pipeline.store.get_result_for_component(self.run_id, name) async def get_status_for_component(self, name: str) -> RunStatus: - status = await self.pipeline.store.get_status_for_component(self.run_id, name) + # when resuming, check previous run_id, otherwise check current run_id + if self.previous_run_id: + status = await self.pipeline.store.get_status_for_component( + self.previous_run_id, name + ) + else: + status = await self.pipeline.store.get_status_for_component( + self.run_id, name + ) + if status is None: return RunStatus.UNKNOWN return RunStatus(status) @@ -266,7 +298,13 @@ async def run(self, data: dict[str, Any]) -> None: will handle the task dependencies. 
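+
+        When `start_from` is provided, execution begins at that component instead
+        of the pipeline roots, typically together with `previous_run_id` so that
+        upstream results can be read from the previous run.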
""" await self.event_notifier.notify_pipeline_started(self.run_id, data) - tasks = [self.run_task(root, data) for root in self.pipeline.roots()] + # start from a specific node if requested, otherwise from roots + if self.start_from: + start_nodes = [self.pipeline.get_node_by_name(self.start_from)] + else: + start_nodes = self.pipeline.roots() + + tasks = [self.run_task(root, data) for root in start_nodes] await asyncio.gather(*tasks) await self.event_notifier.notify_pipeline_finished( self.run_id, await self.pipeline.get_final_results(self.run_id) diff --git a/src/neo4j_graphrag/experimental/pipeline/pipeline.py b/src/neo4j_graphrag/experimental/pipeline/pipeline.py index 91dab34ee..5d9d43809 100644 --- a/src/neo4j_graphrag/experimental/pipeline/pipeline.py +++ b/src/neo4j_graphrag/experimental/pipeline/pipeline.py @@ -18,7 +18,7 @@ import warnings from collections import defaultdict from timeit import default_timer -from typing import Any, Optional, AsyncGenerator +from typing import Any, Optional, AsyncGenerator, Dict import asyncio from neo4j_graphrag.utils.logging import prettify @@ -390,7 +390,9 @@ def validate_parameter_mapping(self) -> None: self.validate_parameter_mapping_for_task(task) self.is_validated = True - def validate_input_data(self, data: dict[str, Any]) -> bool: + def validate_input_data( + self, data: dict[str, Any], from_: Optional[str] = None + ) -> bool: """Performs parameter and data validation before running the pipeline: - Check parameters defined in the connect method - Make sure the missing parameters are present in the input `data` dict. @@ -398,6 +400,8 @@ def validate_input_data(self, data: dict[str, Any]) -> bool: Args: data (dict[str, Any]): input data to use for validation (usually from Pipeline.run) + from_ (Optional[str]): If provided, only validate components that will actually execute + starting from this component Raises: PipelineDefinitionError if any parameter mapping is invalid or if a @@ -405,7 +409,15 @@ def validate_input_data(self, data: dict[str, Any]) -> bool: """ if not self.is_validated: self.validate_parameter_mapping() + + # determine which components need validation + components_to_validate = self._get_components_to_validate(from_) + for task in self._nodes.values(): + # skip validation for components that won't execute + if task.name not in components_to_validate: + continue + if task.name not in self.param_mapping: self.validate_parameter_mapping_for_task(task) missing_params = self.missing_inputs[task.name] @@ -417,6 +429,37 @@ def validate_input_data(self, data: dict[str, Any]) -> bool: ) return True + def _get_components_to_validate(self, from_: Optional[str] = None) -> set[str]: + """Determine which components need validation based on execution context. 
+ + Args: + from_ (Optional[str]): Starting component for execution + + Returns: + set[str]: Set of component names that need validation + """ + if from_ is None: + # no from_ specified, validate all components + return set(self._nodes.keys()) + + # when from_ is specified, only validate components that will actually execute + # this includes the from_ component and all its downstream dependencies + components_to_validate = set() + + def add_downstream_components(component_name: str) -> None: + """Recursively add a component and all its downstream dependencies""" + if component_name in components_to_validate: + return # Already processed + components_to_validate.add(component_name) + + # add all components that depend on this one + for edge in self.next_edges(component_name): + add_downstream_components(edge.end) + + # start from the specified component and add all downstream + add_downstream_components(from_) + return components_to_validate + def validate_parameter_mapping_for_task(self, task: TaskPipelineNode) -> bool: """Make sure that all the parameter mapping for a given task are valid. Does not consider user input yet. @@ -563,19 +606,131 @@ async def event_stream(event: Event) -> None: if event_queue_getter_task and not event_queue_getter_task.done(): event_queue_getter_task.cancel() - async def run(self, data: dict[str, Any]) -> PipelineResult: + async def run( + self, + data: dict[str, Any], + from_: Optional[str] = None, + until: Optional[str] = None, + previous_run_id: Optional[str] = None, + ) -> PipelineResult: + """Run the pipeline, optionally from a specific component or until a specific component. + + Args: + data (dict[str, Any]): The input data for the pipeline + from_ (str | None, optional): If provided, start execution from this component. Defaults to None. + until (str | None, optional): If provided, stop execution after this component. Defaults to None. + previous_run_id (str | None, optional): If provided, resume from this previous run_id. Defaults to None. + + Returns: + PipelineResult: The result of the pipeline execution + """ logger.debug("PIPELINE START") start_time = default_timer() self.invalidate() - self.validate_input_data(data) - orchestrator = Orchestrator(self) + self.validate_input_data(data, from_) + + # create orchestrator with appropriate start_from and stop_after params + orchestrator = Orchestrator( + self, stop_after=until, start_from=from_, previous_run_id=previous_run_id + ) + logger.debug(f"PIPELINE ORCHESTRATOR: {orchestrator.run_id}") await orchestrator.run(data) + end_time = default_timer() logger.debug( f"PIPELINE FINISHED {orchestrator.run_id} in {end_time - start_time}s" ) + return PipelineResult( run_id=orchestrator.run_id, result=await self.get_final_results(orchestrator.run_id), ) + + def dump_state(self, run_id: str) -> Dict[str, Any]: + """Dump the current state of the pipeline to a serializable dictionary. + + Args: + run_id (str): The run_id to dump state for + + Returns: + Dict[str, Any]: A serializable dictionary containing the pipeline state + + Raises: + ValueError: If run_id is None or empty + """ + if not run_id: + raise ValueError("run_id cannot be None or empty") + + pipeline_state: Dict[str, Any] = { + "run_id": run_id, + "store": self.store.dump(run_id), + } + return pipeline_state + + def load_state(self, state: Dict[str, Any]) -> str: + """Load pipeline state from a serialized dictionary. 
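+
+        The state is expected to be a dictionary produced by `dump_state()`. Its
+        `store` entry, if present, is loaded into the pipeline's result store so
+        that a subsequent `run(..., from_=..., previous_run_id=...)` can reuse
+        the stored results.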
+ + Args: + state (dict[str, Any]): Previously serialized pipeline state + + Returns: + str: The run_id from the loaded state + + Raises: + ValueError: If the state is invalid or incompatible with current pipeline + """ + if "run_id" not in state: + raise ValueError("Invalid state: missing run_id") + + run_id: str = state["run_id"] + + # validate pipeline compatibility + self._validate_state_compatibility(state) + + # load store data + if "store" in state: + self.store.load(state["store"]) + + return run_id + + def _validate_state_compatibility(self, state: Dict[str, Any]) -> None: + """Validate that the loaded state is compatible with the current pipeline. + + This checks that the components defined in the pipeline match those + that were present when the state was saved. + + Args: + state (dict[str, Any]): The state to validate + + Raises: + ValueError: If the state is incompatible with the current pipeline + """ + if "store" not in state: + return # no store data to validate + + store_data = state["store"] + if not store_data: + return # empty store, nothing to validate + + # extract component names from the store keys + # keys are in format: "run_id:component_name" or "run_id:component_name:suffix" + stored_components = set() + for key in store_data.keys(): + parts = key.split(":") + if len(parts) >= 2: + component_name = parts[1] + stored_components.add(component_name) + + # get current pipeline component names + current_components = set(self._nodes.keys()) + + # check if stored components are a subset of current components + # this allows for the pipeline to have additional components, but not missing ones + missing_components = stored_components - current_components + if missing_components: + raise ValueError( + f"State is incompatible with current pipeline. " + f"Missing components: {sorted(missing_components)}. " + f"Current pipeline components: {sorted(current_components)}" + ) diff --git a/src/neo4j_graphrag/experimental/pipeline/stores.py b/src/neo4j_graphrag/experimental/pipeline/stores.py index 855546d6d..8d5d9a1dc 100644 --- a/src/neo4j_graphrag/experimental/pipeline/stores.py +++ b/src/neo4j_graphrag/experimental/pipeline/stores.py @@ -90,6 +90,27 @@ async def add_result_for_component( async def get_result_for_component(self, run_id: str, task_name: str) -> Any: return await self.get(self.get_key(run_id, task_name)) + @abc.abstractmethod + def dump(self, run_id: str) -> dict[str, Any]: + """Dump the store state for a specific run_id to a serializable dictionary. + + Args: + run_id (str): The run_id to dump data for + + Returns: + dict[str, Any]: A serializable dictionary containing the store state for the run_id + """ + pass + + @abc.abstractmethod + def load(self, state: dict[str, Any]) -> None: + """Load the store state from a serializable dictionary. + + Args: + state (dict[str, Any]): A serializable dictionary containing the store state + """ + pass + class InMemoryStore(ResultStore): """Simple in-memory store. @@ -115,3 +136,31 @@ def all(self) -> dict[str, Any]: def empty(self) -> None: self._data = {} + + def dump(self, run_id: str) -> dict[str, Any]: + """Dump the store state for a specific run_id to a serializable dictionary. 
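+
+        Only entries whose keys start with the `"{run_id}:"` prefix are included,
+        so the returned dictionary can be persisted (e.g. as JSON) and later
+        restored with `load()`.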
+ + Args: + run_id (str): The run_id to dump data for + + Returns: + dict[str, Any]: A serializable dictionary containing the store state for the run_id + """ + # filter data by run_id prefix + run_id_prefix = f"{run_id}:" + filtered_data = { + key: value + for key, value in self._data.items() + if key.startswith(run_id_prefix) + } + return filtered_data + + def load(self, state: dict[str, Any]) -> None: + """Load the store state from a serializable dictionary. + + Args: + state (dict[str, Any]): A serializable dictionary containing the store state + """ + # add/update data without clearing - safer for concurrent access + # multiple pipelines can load the same state without interfering with each other + self._data.update(state) diff --git a/tests/unit/experimental/pipeline/test_component.py b/tests/unit/experimental/pipeline/test_component.py index ba39aebaf..3d77803b4 100644 --- a/tests/unit/experimental/pipeline/test_component.py +++ b/tests/unit/experimental/pipeline/test_component.py @@ -18,7 +18,12 @@ from neo4j_graphrag.experimental.pipeline import Component from neo4j_graphrag.experimental.pipeline.types.context import RunContext -from .components import ComponentMultiply, ComponentMultiplyWithContext, IntResultModel +from .components import ( + ComponentMultiply, + ComponentMultiplyWithContext, + IntResultModel, + StatefulComponent, +) def test_component_inputs() -> None: @@ -87,3 +92,13 @@ class WrongComponent(Component): "You must implement either `run` or `run_with_context` in Component 'WrongComponent'" in str(e) ) + + +def test_stateful_component_serialize_and_load_state() -> None: + c = StatefulComponent() + c.counter = 42 + state = c.serialize_state() + assert state == {"counter": 42} + c.counter = 0 + c.load_state({"counter": 99}) + assert c.counter == 99 diff --git a/tests/unit/experimental/pipeline/test_pipeline.py b/tests/unit/experimental/pipeline/test_pipeline.py index d37fc65fc..6f30d33ea 100644 --- a/tests/unit/experimental/pipeline/test_pipeline.py +++ b/tests/unit/experimental/pipeline/test_pipeline.py @@ -17,7 +17,7 @@ import asyncio import datetime import tempfile -from typing import Sized +from typing import Sized, Any, Dict from unittest import mock from unittest.mock import AsyncMock, call, patch @@ -38,6 +38,7 @@ ComponentMultiply, ComponentNoParam, ComponentPassThrough, + StatefulComponent, StringResultModel, SlowComponentMultiply, ) @@ -590,3 +591,133 @@ async def callback(event: Event) -> None: events.append(e) assert len(events) == 2 assert len(pipe.callbacks) == 1 + + +@pytest.mark.asyncio +async def test_pipeline_state_dump_and_load() -> None: + pipe = Pipeline() + c1 = StatefulComponent() + c2 = StatefulComponent() + pipe.add_component(c1, "a") + pipe.add_component(c2, "b") + pipe.connect("a", "b", {"value": "a.result"}) + result = await pipe.run({"a": {"value": 5}}) + c1.counter = 123 # simulate state change + state = pipe.dump_state(result.run_id) + c1.counter = 0 + pipe.load_state(state) + assert c1.counter == 123 + + +@pytest.fixture +def stateful_pipeline() -> tuple[Pipeline, StatefulComponent, StatefulComponent]: + """Fixture that creates a pipeline with two stateful components connected in sequence.""" + pipe = Pipeline() + c1 = StatefulComponent() + c2 = StatefulComponent() + pipe.add_component(c1, "a") + pipe.add_component(c2, "b") + pipe.connect("a", "b", {"value": "a.result"}) + return pipe, c1, c2 + + +@pytest.mark.asyncio +async def test_pipeline_run_until_and_resume_from( + stateful_pipeline: tuple[Pipeline, StatefulComponent, 
StatefulComponent], +) -> None: + pipe, _, c2 = stateful_pipeline + + # run until first component + state = await pipe.run_until({"a": {"value": 7}}, stop_after="a") + + # verify that the state contains the run_id + assert "run_id" in state + run_id = state["run_id"] + assert isinstance(run_id, str) + assert len(run_id) > 0 + + # change c2 to a new instance to test resume + c2_new = StatefulComponent() + pipe.set_component("b", c2_new) + + # resume from b + result = await pipe.resume_from(state, {"a": {"value": 5}}, start_from="b") + + # verify we're using the same run_id and correct value propagation + assert result.run_id == run_id + # c2_new should have received value=7 (from a.result) + assert result.result["b"]["result"] == 7 + + +@pytest.mark.asyncio +async def test_pipeline_run_until_with_state_file( + stateful_pipeline: tuple[Pipeline, StatefulComponent, StatefulComponent], +) -> None: + pipe, _, c2 = stateful_pipeline + + # Create a temporary file for state + with tempfile.NamedTemporaryFile(suffix=".json") as state_file: + # run until first component and save state to file + state = await pipe.run_until( + {"a": {"value": 7}}, stop_after="a", state_file=state_file.name + ) + + # verify that the state contains the run_id + assert "run_id" in state + run_id = state["run_id"] + assert isinstance(run_id, str) + assert len(run_id) > 0 + + # change c2 to a new instance to test resume + c2_new = StatefulComponent() + pipe.set_component("b", c2_new) + + # resume from b using state file + result = await pipe.resume_from( + None, {"a": {"value": 5}}, start_from="b", state_file=state_file.name + ) + + # verify we're using the same run_id and correct value propagation + assert result.run_id == run_id + # c2_new should have received value=7 (from a.result) + assert result.result["b"]["result"] == 7 + + +@pytest.mark.asyncio +async def test_pipeline_resume_from_missing_state() -> None: + pipe = Pipeline() + c1 = StatefulComponent() + pipe.add_component(c1, "a") + + with pytest.raises(ValueError, match="No state provided for resume_from."): + await pipe.resume_from(None, {"a": {"value": 5}}, start_from="a") + + +@pytest.mark.asyncio +async def test_pipeline_resume_from_invalid_state() -> None: + pipe = Pipeline() + c1 = StatefulComponent() + pipe.add_component(c1, "a") + + invalid_state: Dict[str, Dict[str, Any]] = { + "components": {}, + "store": {}, + "final_results": {}, + } # missing run_id + + with pytest.raises( + ValueError, match="No run_id found in state. Cannot resume execution." + ): + await pipe.resume_from(invalid_state, {"a": {"value": 5}}, start_from="a") + + +@pytest.mark.asyncio +async def test_pipeline_resume_from_nonexistent_state_file() -> None: + pipe = Pipeline() + c1 = StatefulComponent() + pipe.add_component(c1, "a") + + with pytest.raises(FileNotFoundError): + await pipe.resume_from( + None, {"a": {"value": 5}}, start_from="a", state_file="nonexistent.json" + )