24
24
25
25
try :
26
26
from neo4j_viz import Node , Relationship , VisualizationGraph as NeoVizGraph
27
+
27
28
HAS_NEO4J_VIZ = True
28
29
except ImportError :
29
30
HAS_NEO4J_VIZ = False
@@ -186,13 +187,13 @@ def draw(
186
187
self , path : str , layout : str = "force" , hide_unused_outputs : bool = True
187
188
) -> Any :
188
189
"""Draw the pipeline graph using neo4j-viz.
189
-
190
+
190
191
Args:
191
192
path (str): Path to save the visualization. If the path ends with .html, it will save an HTML file.
192
193
Otherwise, it will save a PNG image.
193
194
layout (str): Layout algorithm to use. Default is "force".
194
195
hide_unused_outputs (bool): Whether to hide unused outputs. Default is True.
195
-
196
+
196
197
Returns:
197
198
Any: The visualization object.
198
199
"""
@@ -211,10 +212,10 @@ def draw(
211
212
212
213
def get_neo4j_viz_graph (self , hide_unused_outputs : bool = True ) -> NeoVizGraph :
213
214
"""Create a neo4j-viz visualization graph from the pipeline.
214
-
215
+
215
216
Args:
216
217
hide_unused_outputs (bool): Whether to hide unused outputs. Default is True.
217
-
218
+
218
219
Returns:
219
220
NeoVizGraph: The neo4j-viz visualization graph.
220
221
"""
@@ -224,12 +225,12 @@ def get_neo4j_viz_graph(self, hide_unused_outputs: bool = True) -> NeoVizGraph:
224
225
"Install it with 'pip install neo4j-viz'."
225
226
)
226
227
self .validate_parameter_mapping ()
227
-
228
+
228
229
nodes = []
229
230
relationships = []
230
231
node_ids = {}
231
232
node_counter = 0
232
-
233
+
233
234
# Create nodes for each component
234
235
for n , node in self ._nodes .items ():
235
236
comp_inputs = "," .join (
@@ -242,11 +243,11 @@ def get_neo4j_viz_graph(self, hide_unused_outputs: bool = True) -> NeoVizGraph:
242
243
id = node_counter ,
243
244
caption = f"{ node .component .__class__ .__name__ } : { n } ({ comp_inputs } )" ,
244
245
size = 20 , # Component nodes are larger
245
- color = "#4C8BF5" # Blue for component nodes
246
+ color = "#4C8BF5" , # Blue for component nodes
246
247
)
247
248
)
248
249
node_counter += 1
249
-
250
+
250
251
# Create nodes for each output field
251
252
for o in node .component .component_outputs :
252
253
param_node_name = f"{ n } .{ o } "
@@ -256,19 +257,17 @@ def get_neo4j_viz_graph(self, hide_unused_outputs: bool = True) -> NeoVizGraph:
256
257
id = node_counter ,
257
258
caption = o ,
258
259
size = 10 , # Output nodes are smaller
259
- color = "#34A853" # Green for output nodes
260
+ color = "#34A853" , # Green for output nodes
260
261
)
261
262
)
262
263
# Connect component to its output
263
264
relationships .append (
264
265
Relationship (
265
- source = node_ids [n ],
266
- target = node_ids [param_node_name ],
267
- caption = ""
266
+ source = node_ids [n ], target = node_ids [param_node_name ], caption = ""
268
267
)
269
268
)
270
269
node_counter += 1
271
-
270
+
272
271
# Create edges between components and their inputs
273
272
for component_name , params in self .param_mapping .items ():
274
273
for param , mapping in params .items ():
@@ -278,27 +277,27 @@ def get_neo4j_viz_graph(self, hide_unused_outputs: bool = True) -> NeoVizGraph:
278
277
source_output_node = f"{ source_component } .{ source_param_name } "
279
278
else :
280
279
source_output_node = source_component
281
-
280
+
282
281
if source_output_node in node_ids and component_name in node_ids :
283
282
relationships .append (
284
283
Relationship (
285
284
source = node_ids [source_output_node ],
286
285
target = node_ids [component_name ],
287
286
caption = param ,
288
- color = "#EA4335" # Red for parameter connections
287
+ color = "#EA4335" , # Red for parameter connections
289
288
)
290
289
)
291
-
290
+
292
291
# Filter unused outputs if requested
293
292
if hide_unused_outputs :
294
293
used_nodes = set ()
295
294
for rel in relationships :
296
295
used_nodes .add (rel .source )
297
296
used_nodes .add (rel .target )
298
-
297
+
299
298
filtered_nodes = [node for node in nodes if node .id in used_nodes ]
300
299
return NeoVizGraph (nodes = filtered_nodes , relationships = relationships )
301
-
300
+
302
301
return NeoVizGraph (nodes = nodes , relationships = relationships )
303
302
304
303
def add_component (self , component : Component , name : str ) -> None :
0 commit comments