Skip to content

Commit b90003b

Browse files
authored
Raise error on wrong parameter mapping during pipeline definition (#124)
* Raise error when param is mapped twice or is not a valid input * Pipeline invalidation method, to reinitialize param mapping and missing params dict
1 parent 75138c1 commit b90003b

File tree

3 files changed

+80
-2
lines changed

3 files changed

+80
-2
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121

2222
### Changed
2323
- Pipeline run method now return a PipelineResult object.
24+
- Improved parameter validation for pipelines (#124). Pipeline now raise an error before a run starts if:
25+
- the same parameter is mapped twice
26+
- or a parameter is defined in the mapping but is not a valid component input
2427

2528

2629
## 0.5.0

src/neo4j_graphrag/experimental/pipeline/pipeline.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -430,7 +430,7 @@ def add_component(self, component: Component, name: str) -> None:
430430
task = TaskPipelineNode(name, component)
431431
self.add_node(task)
432432
# invalidate the pipeline if it was already validated
433-
self.is_validated = False
433+
self.invalidate()
434434

435435
def set_component(self, name: str, component: Component) -> None:
436436
"""Replace a component with another. If 'name' is not yet in the pipeline,
@@ -439,7 +439,7 @@ def set_component(self, name: str, component: Component) -> None:
439439
task = TaskPipelineNode(name, component)
440440
self.set_node(task)
441441
# invalidate the pipeline if it was already validated
442-
self.is_validated = False
442+
self.invalidate()
443443

444444
def connect(
445445
self,
@@ -475,7 +475,12 @@ def connect(
475475
if self.is_cyclic():
476476
raise PipelineDefinitionError("Cyclic graph are not allowed")
477477
# invalidate the pipeline if it was already validated
478+
self.invalidate()
479+
480+
def invalidate(self) -> None:
478481
self.is_validated = False
482+
self.param_mapping = defaultdict(dict)
483+
self.missing_inputs = defaultdict()
479484

480485
def validate_parameter_mapping(self) -> None:
481486
"""Go through the graph and make sure parameter mapping is valid
@@ -520,6 +525,8 @@ def validate_parameter_mapping_for_task(self, task: TaskPipelineNode) -> bool:
520525
521526
Considering the naming {param => target (component, [output_parameter]) },
522527
the mapping is valid if:
528+
- 'param' is a valid input for task
529+
- 'param' has not already been mapped
523530
- The target component exists in the pipeline and, if specified, the
524531
target output parameter is a valid field in the target component's
525532
result model.
@@ -543,6 +550,14 @@ def validate_parameter_mapping_for_task(self, task: TaskPipelineNode) -> bool:
543550
# check that the previous component is actually returning
544551
# the mapped parameter
545552
for param, path in edge_inputs.items():
553+
if param in self.param_mapping[task.name]:
554+
raise PipelineDefinitionError(
555+
f"Parameter '{param}' already mapped to {self.param_mapping[task.name][param]}"
556+
)
557+
if param not in task.component.component_inputs:
558+
raise PipelineDefinitionError(
559+
f"Parameter '{param}' is not a valid input for component '{task.name}' of type '{task.component.__class__.__name__}'"
560+
)
546561
try:
547562
source_component_name, param_name = path.split(".")
548563
except ValueError:

tests/unit/experimental/pipeline/test_pipeline.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,33 @@ def test_pipeline_parameter_validation_one_component_all_good() -> None:
9797
assert is_valid is True
9898

9999

100+
def test_pipeline_invalidate() -> None:
101+
pipe = Pipeline()
102+
pipe.is_validated = True
103+
pipe.param_mapping = {"a": {"key": {"component": "component", "param": "param"}}}
104+
pipe.missing_inputs = {"a": ["other_key"]}
105+
pipe.invalidate()
106+
assert pipe.is_validated is False
107+
assert len(pipe.param_mapping) == 0
108+
assert len(pipe.missing_inputs) == 0
109+
110+
111+
def test_pipeline_parameter_validation_called_twice() -> None:
112+
pipe = Pipeline()
113+
component_a = ComponentPassThrough()
114+
component_b = ComponentPassThrough()
115+
pipe.add_component(component_a, "a")
116+
pipe.add_component(component_b, "b")
117+
pipe.connect("a", "b", {"value": "a.result"})
118+
is_valid = pipe.validate_parameter_mapping_for_task(pipe.get_node_by_name("b"))
119+
assert is_valid is True
120+
with pytest.raises(PipelineDefinitionError):
121+
pipe.validate_parameter_mapping_for_task(pipe.get_node_by_name("b"))
122+
pipe.invalidate()
123+
is_valid = pipe.validate_parameter_mapping_for_task(pipe.get_node_by_name("b"))
124+
assert is_valid is True
125+
126+
100127
def test_pipeline_parameter_validation_one_component_input_param_missing() -> None:
101128
pipe = Pipeline()
102129
component_a = ComponentPassThrough()
@@ -105,6 +132,39 @@ def test_pipeline_parameter_validation_one_component_input_param_missing() -> No
105132
assert pipe.missing_inputs["a"] == ["value"]
106133

107134

135+
def test_pipeline_parameter_validation_param_mapped_twice() -> None:
136+
pipe = Pipeline()
137+
component_a = ComponentPassThrough()
138+
component_b = ComponentPassThrough()
139+
component_c = ComponentPassThrough()
140+
pipe.add_component(component_a, "a")
141+
pipe.add_component(component_b, "b")
142+
pipe.add_component(component_c, "c")
143+
pipe.connect("a", "c", {"value": "a.result"})
144+
pipe.connect("b", "c", {"value": "b.result"})
145+
with pytest.raises(PipelineDefinitionError) as excinfo:
146+
pipe.validate_parameter_mapping_for_task(pipe.get_node_by_name("c"))
147+
assert (
148+
"Parameter 'value' already mapped to {'component': 'a', 'param': 'result'}"
149+
in str(excinfo)
150+
)
151+
152+
153+
def test_pipeline_parameter_validation_unexpected_input() -> None:
154+
pipe = Pipeline()
155+
component_a = ComponentPassThrough()
156+
component_b = ComponentPassThrough()
157+
pipe.add_component(component_a, "a")
158+
pipe.add_component(component_b, "b")
159+
pipe.connect("a", "b", {"unexpected_input_name": "a.result"})
160+
with pytest.raises(PipelineDefinitionError) as excinfo:
161+
pipe.validate_parameter_mapping_for_task(pipe.get_node_by_name("b"))
162+
assert (
163+
"Parameter 'unexpected_input_name' is not a valid input for component 'b' of type 'ComponentPassThrough'"
164+
in str(excinfo)
165+
)
166+
167+
108168
def test_pipeline_parameter_validation_connected_components_input() -> None:
109169
"""Parameter for component 'b' comes from the pipeline inputs"""
110170
pipe = Pipeline()

0 commit comments

Comments
 (0)