Pipeline state dump and load #352
base: main
Changes from 14 commits
@@ -18,7 +18,7 @@
 import warnings
 from collections import defaultdict
 from timeit import default_timer
-from typing import Any, Optional, AsyncGenerator
+from typing import Any, Optional, AsyncGenerator, Dict
 import asyncio

 from neo4j_graphrag.utils.logging import prettify
@@ -140,6 +140,7 @@ def __init__(
         }
         """
         self.missing_inputs: dict[str, list[str]] = defaultdict()
+        self._current_run_id: Optional[str] = None
[Review comment] This cannot be saved in the Pipeline instance, since concurrent runs will override it.
[Author reply] I will move it back to

     @classmethod
     def from_template(
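To make the review comment above concrete, here is a minimal sketch of the race (hypothetical component names and inputs; assumes `pipeline` is a built `Pipeline` instance shared between asyncio tasks):

import asyncio

# Two concurrent runs share one Pipeline instance, so each orchestrator
# overwrites self._current_run_id in turn; whichever run sets it last "wins".
async def concurrent_runs(pipeline):
    await asyncio.gather(
        pipeline.run({"splitter": {"text": "document A"}}),
        pipeline.run({"splitter": {"text": "document B"}}),
    )
    # dump_state() now silently captures only the last-registered run,
    # not necessarily the one the caller intended to checkpoint.
    return pipeline.dump_state()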
@@ -390,22 +391,34 @@ def validate_parameter_mapping(self) -> None:
             self.validate_parameter_mapping_for_task(task)
         self.is_validated = True

-    def validate_input_data(self, data: dict[str, Any]) -> bool:
+    def validate_input_data(
+        self, data: dict[str, Any], from_: Optional[str] = None
+    ) -> bool:
         """Performs parameter and data validation before running the pipeline:
         - Check parameters defined in the connect method
         - Make sure the missing parameters are present in the input `data` dict.

         Args:
             data (dict[str, Any]): input data to use for validation
                 (usually from Pipeline.run)
+            from_ (Optional[str]): If provided, only validate components that will actually execute
+                starting from this component

         Raises:
             PipelineDefinitionError if any parameter mapping is invalid or if a
             parameter is missing.
         """
         if not self.is_validated:
             self.validate_parameter_mapping()
+
+        # determine which components need validation
+        components_to_validate = self._get_components_to_validate(from_)
+
         for task in self._nodes.values():
+            # skip validation for components that won't execute
+            if task.name not in components_to_validate:
+                continue
+
             if task.name not in self.param_mapping:
                 self.validate_parameter_mapping_for_task(task)
             missing_params = self.missing_inputs[task.name]
@@ -417,6 +430,37 @@ def validate_input_data(self, data: dict[str, Any]) -> bool:
             )
         return True

+    def _get_components_to_validate(self, from_: Optional[str] = None) -> set[str]:
+        """Determine which components need validation based on execution context.
+
+        Args:
+            from_ (Optional[str]): Starting component for execution
+
+        Returns:
+            set[str]: Set of component names that need validation
+        """
+        if from_ is None:
+            # no from_ specified, validate all components
+            return set(self._nodes.keys())
+
+        # when from_ is specified, only validate components that will actually execute
+        # this includes the from_ component and all its downstream dependencies
+        components_to_validate = set()
+
+        def add_downstream_components(component_name: str) -> None:
+            """Recursively add a component and all its downstream dependencies"""
+            if component_name in components_to_validate:
+                return  # already processed
+            components_to_validate.add(component_name)
+
+            # add all components that depend on this one
+            for edge in self.next_edges(component_name):
+                add_downstream_components(edge.end)
+
+        # start from the specified component and add all downstream
+        add_downstream_components(from_)
+        return components_to_validate
+
     def validate_parameter_mapping_for_task(self, task: TaskPipelineNode) -> bool:
         """Make sure that all the parameter mapping for a given task are valid.
         Does not consider user input yet.
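The traversal above is a depth-first walk over outgoing edges. A standalone sketch of the same logic on a toy DAG (hypothetical component names; `next_edges` replaced by a plain adjacency dict):

# Toy DAG: a -> b, b -> c, b -> d. Starting from "b" should validate
# only {"b", "c", "d"}; "a" will not execute, so its inputs are not checked.
edges = {"a": ["b"], "b": ["c", "d"], "c": [], "d": []}

def components_to_validate(from_: str) -> set[str]:
    seen: set[str] = set()

    def walk(name: str) -> None:
        if name in seen:
            return  # already processed; avoids revisiting shared children
        seen.add(name)
        for child in edges[name]:
            walk(child)

    walk(from_)
    return seen

assert components_to_validate("b") == {"b", "c", "d"}
assert components_to_validate("a") == {"a", "b", "c", "d"}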
@@ -563,19 +607,134 @@ async def event_stream(event: Event) -> None:
         if event_queue_getter_task and not event_queue_getter_task.done():
             event_queue_getter_task.cancel()

-    async def run(self, data: dict[str, Any]) -> PipelineResult:
+    async def run(
+        self,
+        data: dict[str, Any],
+        from_: Optional[str] = None,
+        until: Optional[str] = None,
+    ) -> PipelineResult:
+        """Run the pipeline, optionally from a specific component or until a specific component.
+
+        Args:
+            data (dict[str, Any]): The input data for the pipeline
+            from_ (str | None, optional): If provided, start execution from this component. Defaults to None.
+            until (str | None, optional): If provided, stop execution after this component. Defaults to None.
+
+        Returns:
+            PipelineResult: The result of the pipeline execution
+        """
         logger.debug("PIPELINE START")
         start_time = default_timer()
         self.invalidate()
-        self.validate_input_data(data)
-        orchestrator = Orchestrator(self)
+        self.validate_input_data(data, from_)
+
+        # create orchestrator with appropriate start_from and stop_after params
+        # if current run_id exists (from loaded state), use it to continue the same run
+        orchestrator = Orchestrator(
+            self, stop_after=until, start_from=from_, run_id=self._current_run_id
+        )
+
+        # Track the current run_id
+        self._current_run_id = orchestrator.run_id
+
         logger.debug(f"PIPELINE ORCHESTRATOR: {orchestrator.run_id}")
         await orchestrator.run(data)
+
         end_time = default_timer()
         logger.debug(
             f"PIPELINE FINISHED {orchestrator.run_id} in {end_time - start_time}s"
         )
+
         return PipelineResult(
             run_id=orchestrator.run_id,
             result=await self.get_final_results(orchestrator.run_id),
         )
+
+    def dump_state(self) -> Dict[str, Any]:
"""Dump the current state of the pipeline to a serializable dictionary. | ||
|
||
Returns: | ||
Dict[str, Any]: A serializable dictionary containing the pipeline state | ||
|
||
Raises: | ||
ValueError: If no pipeline run has been executed yet | ||
""" | ||
if self._current_run_id is None: | ||
raise ValueError( | ||
"No pipeline run has been executed yet. Cannot dump state without a run_id." | ||
) | ||
|
||
pipeline_state: Dict[str, Any] = { | ||
"run_id": self._current_run_id, | ||
"store": self.store.dump(self._current_run_id), | ||
} | ||
return pipeline_state | ||
|
||
def load_state(self, state: Dict[str, Any]) -> None: | ||
"""Load pipeline state from a serialized dictionary. | ||
|
||
Args: | ||
state (dict[str, Any]): Previously serialized pipeline state | ||
|
||
Raises: | ||
ValueError: If the state is invalid or incompatible with current pipeline | ||
""" | ||
if "run_id" not in state: | ||
raise ValueError("Invalid state: missing run_id") | ||
|
||
run_id = state["run_id"] | ||
|
||
# validate pipeline compatibility | ||
self._validate_state_compatibility(state) | ||
|
||
# set the current run_id | ||
self._current_run_id = run_id | ||
|
||
# load pipeline state attributes | ||
if "is_validated" in state: | ||
self.is_validated = state["is_validated"] | ||
|
||
# load store data | ||
if "store" in state: | ||
self.store.load(run_id, state["store"]) | ||
|
||
def _validate_state_compatibility(self, state: Dict[str, Any]) -> None: | ||
"""Validate that the loaded state is compatible with the current pipeline. | ||
|
||
This checks that the components defined in the pipeline match those | ||
that were present when the state was saved. | ||
|
||
Args: | ||
state (dict[str, Any]): The state to validate | ||
|
||
Raises: | ||
ValueError: If the state is incompatible with the current pipeline | ||
""" | ||
if "store" not in state: | ||
return # no store data to validate | ||
|
||
store_data = state["store"] | ||
if not store_data: | ||
return # empty store, nothing to validate | ||
|
||
# extract component names from the store keys | ||
# keys are in format: "run_id:component_name" or "run_id:component_name:suffix" | ||
stored_components = set() | ||
for key in store_data.keys(): | ||
parts = key.split(":") | ||
if len(parts) >= 2: | ||
component_name = parts[1] | ||
stored_components.add(component_name) | ||
|
||
# get current pipeline component names | ||
current_components = set(self._nodes.keys()) | ||
|
||
# check if stored components are a subset of current components | ||
# this allows for the pipeline to have additional components, but not missing ones | ||
missing_components = stored_components - current_components | ||
+        if missing_components:
+            raise ValueError(
+                f"State is incompatible with current pipeline. "
+                f"Missing components: {sorted(missing_components)}. "
+                f"Current pipeline components: {sorted(current_components)}"
+            )
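Taken together, `until`, `dump_state`, `load_state`, and `from_` enable a checkpoint-and-resume workflow. A sketch under assumptions: a two-component pipeline `splitter -> writer` (hypothetical names) whose store dump is JSON-serializable:

import json

async def checkpoint_and_resume(pipeline):
    # first half: run up to and including "splitter", then persist the state
    await pipeline.run({"splitter": {"text": "some document"}}, until="splitter")
    with open("pipeline_state.json", "w") as f:
        json.dump(pipeline.dump_state(), f)

    # later, possibly on a freshly built but identical pipeline: reload the
    # state and resume; load_state restores _current_run_id, so run(from_=...)
    # continues the same run against the stored intermediate results
    with open("pipeline_state.json") as f:
        pipeline.load_state(json.load(f))
    return await pipeline.run({}, from_="writer")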