Skip to content

Commit 9a92182

Browse files
committed
Fix formatting issues in pipeline files
1 parent b2971a9 commit 9a92182

File tree

2 files changed

+20
-21
lines changed

2 files changed

+20
-21
lines changed

src/neo4j_graphrag/experimental/pipeline/pipeline.py

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
try:
2626
from neo4j_viz import Node, Relationship, VisualizationGraph as NeoVizGraph
27+
2728
HAS_NEO4J_VIZ = True
2829
except ImportError:
2930
HAS_NEO4J_VIZ = False
@@ -186,13 +187,13 @@ def draw(
186187
self, path: str, layout: str = "force", hide_unused_outputs: bool = True
187188
) -> Any:
188189
"""Draw the pipeline graph using neo4j-viz.
189-
190+
190191
Args:
191192
path (str): Path to save the visualization. If the path ends with .html, it will save an HTML file.
192193
Otherwise, it will save a PNG image.
193194
layout (str): Layout algorithm to use. Default is "force".
194195
hide_unused_outputs (bool): Whether to hide unused outputs. Default is True.
195-
196+
196197
Returns:
197198
Any: The visualization object.
198199
"""
@@ -211,10 +212,10 @@ def draw(
211212

212213
def get_neo4j_viz_graph(self, hide_unused_outputs: bool = True) -> NeoVizGraph:
213214
"""Create a neo4j-viz visualization graph from the pipeline.
214-
215+
215216
Args:
216217
hide_unused_outputs (bool): Whether to hide unused outputs. Default is True.
217-
218+
218219
Returns:
219220
NeoVizGraph: The neo4j-viz visualization graph.
220221
"""
@@ -224,12 +225,12 @@ def get_neo4j_viz_graph(self, hide_unused_outputs: bool = True) -> NeoVizGraph:
224225
"Install it with 'pip install neo4j-viz'."
225226
)
226227
self.validate_parameter_mapping()
227-
228+
228229
nodes = []
229230
relationships = []
230231
node_ids = {}
231232
node_counter = 0
232-
233+
233234
# Create nodes for each component
234235
for n, node in self._nodes.items():
235236
comp_inputs = ",".join(
@@ -242,11 +243,11 @@ def get_neo4j_viz_graph(self, hide_unused_outputs: bool = True) -> NeoVizGraph:
242243
id=node_counter,
243244
caption=f"{node.component.__class__.__name__}: {n}({comp_inputs})",
244245
size=20, # Component nodes are larger
245-
color="#4C8BF5" # Blue for component nodes
246+
color="#4C8BF5", # Blue for component nodes
246247
)
247248
)
248249
node_counter += 1
249-
250+
250251
# Create nodes for each output field
251252
for o in node.component.component_outputs:
252253
param_node_name = f"{n}.{o}"
@@ -256,19 +257,17 @@ def get_neo4j_viz_graph(self, hide_unused_outputs: bool = True) -> NeoVizGraph:
256257
id=node_counter,
257258
caption=o,
258259
size=10, # Output nodes are smaller
259-
color="#34A853" # Green for output nodes
260+
color="#34A853", # Green for output nodes
260261
)
261262
)
262263
# Connect component to its output
263264
relationships.append(
264265
Relationship(
265-
source=node_ids[n],
266-
target=node_ids[param_node_name],
267-
caption=""
266+
source=node_ids[n], target=node_ids[param_node_name], caption=""
268267
)
269268
)
270269
node_counter += 1
271-
270+
272271
# Create edges between components and their inputs
273272
for component_name, params in self.param_mapping.items():
274273
for param, mapping in params.items():
@@ -278,27 +277,27 @@ def get_neo4j_viz_graph(self, hide_unused_outputs: bool = True) -> NeoVizGraph:
278277
source_output_node = f"{source_component}.{source_param_name}"
279278
else:
280279
source_output_node = source_component
281-
280+
282281
if source_output_node in node_ids and component_name in node_ids:
283282
relationships.append(
284283
Relationship(
285284
source=node_ids[source_output_node],
286285
target=node_ids[component_name],
287286
caption=param,
288-
color="#EA4335" # Red for parameter connections
287+
color="#EA4335", # Red for parameter connections
289288
)
290289
)
291-
290+
292291
# Filter unused outputs if requested
293292
if hide_unused_outputs:
294293
used_nodes = set()
295294
for rel in relationships:
296295
used_nodes.add(rel.source)
297296
used_nodes.add(rel.target)
298-
297+
299298
filtered_nodes = [node for node in nodes if node.id in used_nodes]
300299
return NeoVizGraph(nodes=filtered_nodes, relationships=relationships)
301-
300+
302301
return NeoVizGraph(nodes=nodes, relationships=relationships)
303302

304303
def add_component(self, component: Component, name: str) -> None:

tests/unit/experimental/pipeline/test_pipeline.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -391,15 +391,15 @@ def test_pipeline_to_neo4j_viz() -> None:
391391
# - 2 components 'a' and 'b'
392392
# - 2 outputs 'a.result' and 'b.result' (neo4j-viz implementation includes both)
393393
assert len(g.nodes) == 4
394-
394+
395395
# Count component nodes
396396
component_nodes = [node for node in g.nodes if node.size == 20]
397397
assert len(component_nodes) == 2
398-
398+
399399
# Count output nodes
400400
output_nodes = [node for node in g.nodes if node.size == 10]
401401
assert len(output_nodes) == 2
402-
402+
403403
g = pipe.get_neo4j_viz_graph(hide_unused_outputs=False)
404404
# 4 nodes:
405405
# - 2 components 'a' and 'b'

0 commit comments

Comments
 (0)