
Pipeline state dump and load #352


Draft · wants to merge 21 commits into base: main

Commits (21)
4a93761  Serialize/deserialize component state (NathalieCharbel, May 15, 2025)
f940506  Run pipeline until/Resume pipeline from (NathalieCharbel, May 15, 2025)
12d2e26  Remove in-memory storage support for pipeline state (NathalieCharbel, Jun 4, 2025)
8a00015  Add pipeline run_id and ability to save and load state from JSON file (NathalieCharbel, Jun 4, 2025)
5f0c892  Add unit tests (NathalieCharbel, Jun 4, 2025)
22722b8  Update changelog and docs (NathalieCharbel, Jun 4, 2025)
9a57046  Ruff (NathalieCharbel, Jun 4, 2025)
d1f7389  Remove state management for component (NathalieCharbel, Jun 10, 2025)
8bc8ac0  Remove resume_from and run_until and reuse existing run interface (NathalieCharbel, Jun 10, 2025)
f76e5ca  Add dump and load to InMemoryStore (NathalieCharbel, Jun 10, 2025)
2cc5b99  Ruff (NathalieCharbel, Jun 10, 2025)
2a819ea  Add ability to load and dump state by run_id (NathalieCharbel, Jun 11, 2025)
472eaf1  Allow orchestrator to run using a run_id from a previous run (NathalieCharbel, Jun 11, 2025)
13254dd  Refactor pipeline and validate loaded state (NathalieCharbel, Jun 11, 2025)
a9413d3  Refactor pipeline run_id management (NathalieCharbel, Jun 13, 2025)
64bdc66  Ensure previous run_ids are kept in store (NathalieCharbel, Jun 13, 2025)
d020a7e  Ensure resume run with different run_ids (NathalieCharbel, Jun 13, 2025)
c97bb97  Cleanup stores (NathalieCharbel, Jun 16, 2025)
74e7db0  Ensure proper handling of previous run_ids (NathalieCharbel, Jun 16, 2025)
0ca4c60  Update changelog and docs (NathalieCharbel, Jun 16, 2025)
2c091c2  Fix orchestrator's way of handling tasks on complete and transitions … (NathalieCharbel, Jul 11, 2025)
134 changes: 119 additions & 15 deletions src/neo4j_graphrag/experimental/pipeline/pipeline.py
@@ -140,6 +140,7 @@ def __init__(
}
"""
self.missing_inputs: dict[str, list[str]] = defaultdict()
self._current_run_id: Optional[str] = None
Contributor commented:

This cannot be saved in the Pipeline instance, since concurrent runs will override it.

NathalieCharbel (Author) commented on Jun 12, 2025:

I will move it back to the dump() function. I think we should keep creating different run_ids even after resuming the same pipeline, and dump the state based on the previous ones. We could keep track of the run_ids of the same pipeline in the state. This should resolve the concurrency issue, right?
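
A minimal sketch of that idea, assuming the state is a plain dict and using hypothetical names rather than this PR's final API: every resume mints a fresh run_id, and the state remembers all run_ids the pipeline has used, so concurrent runs never share one.

from uuid import uuid4

def start_run(state: dict) -> str:
    # mint a fresh run_id and record it alongside all previous ones
    run_id = str(uuid4())
    state.setdefault("previous_run_ids", []).append(run_id)
    return run_id

state: dict = {}
first = start_run(state)   # initial run
second = start_run(state)  # a resumed run still gets its own run_id
assert state["previous_run_ids"] == [first, second]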


@classmethod
def from_template(
@@ -390,22 +391,34 @@ def validate_parameter_mapping(self) -> None:
self.validate_parameter_mapping_for_task(task)
self.is_validated = True

def validate_input_data(self, data: dict[str, Any]) -> bool:
def validate_input_data(
self, data: dict[str, Any], from_: Optional[str] = None
) -> bool:
"""Performs parameter and data validation before running the pipeline:
- Check parameters defined in the connect method
- Make sure the missing parameters are present in the input `data` dict.

Args:
data (dict[str, Any]): input data to use for validation
(usually from Pipeline.run)
from_ (Optional[str]): If provided, only validate components that will actually execute
starting from this component

Raises:
PipelineDefinitionError if any parameter mapping is invalid or if a
parameter is missing.
"""
if not self.is_validated:
self.validate_parameter_mapping()

# determine which components need validation
components_to_validate = self._get_components_to_validate(from_)

for task in self._nodes.values():
# skip validation for components that won't execute
if task.name not in components_to_validate:
continue

if task.name not in self.param_mapping:
self.validate_parameter_mapping_for_task(task)
missing_params = self.missing_inputs[task.name]
@@ -417,6 +430,37 @@ def validate_input_data(self, data: dict[str, Any]) -> bool:
)
return True

def _get_components_to_validate(self, from_: Optional[str] = None) -> set[str]:
"""Determine which components need validation based on execution context.

Args:
from_ (Optional[str]): Starting component for execution

Returns:
set[str]: Set of component names that need validation
"""
if from_ is None:
# no from_ specified, validate all components
return set(self._nodes.keys())

# when from_ is specified, only validate components that will actually execute
# this includes the from_ component and all its downstream dependencies
components_to_validate = set()

def add_downstream_components(component_name: str) -> None:
"""Recursively add a component and all its downstream dependencies"""
if component_name in components_to_validate:
return # Already processed
components_to_validate.add(component_name)

# add all components that depend on this one
for edge in self.next_edges(component_name):
add_downstream_components(edge.end)

# start from the specified component and add all downstream
add_downstream_components(from_)
return components_to_validate
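
For illustration, a standalone sketch of the same traversal over a plain adjacency dict (a hypothetical linear pipeline a -> b -> c; the real method walks self.next_edges()):

edges = {"a": ["b"], "b": ["c"], "c": []}

def downstream(start: str) -> set[str]:
    # collect `start` and everything reachable from it
    seen: set[str] = set()

    def visit(name: str) -> None:
        if name in seen:
            return
        seen.add(name)
        for nxt in edges[name]:
            visit(nxt)

    visit(start)
    return seen

# upstream "a" is skipped: it will not re-execute on resume,
# so its inputs are not required in the new `data` dict
assert downstream("b") == {"b", "c"}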

def validate_parameter_mapping_for_task(self, task: TaskPipelineNode) -> bool:
"""Make sure that all the parameter mapping for a given task are valid.
Does not consider user input yet.
@@ -582,10 +626,16 @@ async def run(
logger.debug("PIPELINE START")
start_time = default_timer()
self.invalidate()
self.validate_input_data(data)
self.validate_input_data(data, from_)

# create orchestrator with appropriate start_from and stop_after params
orchestrator = Orchestrator(self, stop_after=until, start_from=from_)
# if current run_id exists (from loaded state), use it to continue the same run
orchestrator = Orchestrator(
self, stop_after=until, start_from=from_, run_id=self._current_run_id
)

# Track the current run_id
self._current_run_id = orchestrator.run_id

logger.debug(f"PIPELINE ORCHESTRATOR: {orchestrator.run_id}")
await orchestrator.run(data)
@@ -600,20 +650,23 @@ async def run(
result=await self.get_final_results(orchestrator.run_id),
)
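
To see the feature end to end, a hedged usage sketch: the Component/Pipeline setup follows the library's documented API (import paths assumed from the package layout), while until/from_, dump_state, and load_state are this PR's proposed interface, so the exact signatures are provisional.

import asyncio
import json

from neo4j_graphrag.experimental.pipeline import Pipeline
from neo4j_graphrag.experimental.pipeline.component import Component, DataModel

class IntResult(DataModel):
    result: int

class Add(Component):
    async def run(self, number1: int, number2: int) -> IntResult:
        return IntResult(result=number1 + number2)

async def main() -> None:
    pipe = Pipeline()
    pipe.add_component(Add(), "first")
    pipe.add_component(Add(), "second")
    pipe.connect("first", "second", input_config={"number1": "first.result"})

    data = {"first": {"number1": 1, "number2": 2}, "second": {"number2": 3}}

    # run only up to "first", then persist run_id + store contents to disk
    await pipe.run(data, until="first")
    with open("pipeline_state.json", "w") as f:
        json.dump(pipe.dump_state(), f)

    # later, possibly in a fresh process: rebuild the pipeline, reload, resume
    with open("pipeline_state.json") as f:
        pipe.load_state(json.load(f))
    result = await pipe.run({"second": {"number2": 3}}, from_="second")
    print(result.result)  # "first" is served from the loaded state

asyncio.run(main())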

def dump_state(self, run_id: str) -> Dict[str, Any]:
def dump_state(self) -> Dict[str, Any]:
"""Dump the current state of the pipeline to a serializable dictionary.

Args:
run_id: The run_id that was used when the pipeline was executed

Returns:
Dict[str, Any]: A serializable dictionary containing the pipeline state

Raises:
ValueError: If no pipeline run has been executed yet
"""
if self._current_run_id is None:
raise ValueError(
"No pipeline run has been executed yet. Cannot dump state without a run_id."
)

pipeline_state: Dict[str, Any] = {
"run_id": run_id,
"store": self.store.dump(),
"final_results": self.final_results.dump(),
"is_validated": self.is_validated,
"run_id": self._current_run_id,
"store": self.store.dump(self._current_run_id),
}
return pipeline_state
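
For reference, a hypothetical shape of the returned dict (the store-key layout is inferred from the run_id:component_name parsing in _validate_state_compatibility below, and the per-component payloads are illustrative only):

state = {
    "run_id": "9c1b...",                  # UUID of the dumped run
    "store": {
        "9c1b...:first": {"result": 3},   # one entry per executed component
        "9c1b...:second": {"result": 6},
    },
}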

@@ -622,15 +675,66 @@ def load_state(self, state: Dict[str, Any]) -> None:

Args:
state (dict[str, Any]): Previously serialized pipeline state

Raises:
ValueError: If the state is invalid or incompatible with current pipeline
"""
if "run_id" not in state:
raise ValueError("Invalid state: missing run_id")

run_id = state["run_id"]

# validate pipeline compatibility
self._validate_state_compatibility(state)

# set the current run_id
self._current_run_id = run_id

# load pipeline state attributes
if "is_validated" in state:
self.is_validated = state["is_validated"]

# load store data
if "store" in state:
self.store.load(state["store"])
self.store.load(run_id, state["store"])

# load final results
if "final_results" in state:
self.final_results.load(state["final_results"])

def _validate_state_compatibility(self, state: Dict[str, Any]) -> None:
"""Validate that the loaded state is compatible with the current pipeline.

This checks that the components defined in the pipeline match those
that were present when the state was saved.

Args:
state (dict[str, Any]): The state to validate

Raises:
ValueError: If the state is incompatible with the current pipeline
"""
if "store" not in state:
return # no store data to validate

store_data = state["store"]
if not store_data:
return # empty store, nothing to validate

# extract component names from the store keys
# keys are in format: "run_id:component_name" or "run_id:component_name:suffix"
stored_components = set()
for key in store_data.keys():
parts = key.split(":")
if len(parts) >= 2:
component_name = parts[1]
stored_components.add(component_name)

# get current pipeline component names
current_components = set(self._nodes.keys())

# check if stored components are a subset of current components
# this allows for the pipeline to have additional components, but not missing ones
missing_components = stored_components - current_components
if missing_components:
raise ValueError(
f"State is incompatible with current pipeline. "
f"Missing components: {sorted(missing_components)}. "
f"Current pipeline components: {sorted(current_components)}"
)
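
A compact sketch of the same subset rule on hypothetical data:

store_keys = ["run-1:split", "run-1:embed:chunks"]
stored = {key.split(":")[1] for key in store_keys if len(key.split(":")) >= 2}

assert stored <= {"split", "embed", "write"}  # extra current components are allowed
assert not stored <= {"split", "write"}       # a missing "embed" would raise ValueError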