23
23
from neo4j_graphrag .utils .logging import prettify
24
24
25
25
try :
26
- import pygraphviz as pgv
26
+ from neo4j_viz import Node , Relationship , VisualizationGraph as NeoVizGraph
27
+ HAS_NEO4J_VIZ = True
27
28
except ImportError :
28
- pgv = None
29
+ HAS_NEO4J_VIZ = False
29
30
30
31
from pydantic import BaseModel
31
32
@@ -182,40 +183,93 @@ def show_as_dict(self) -> dict[str, Any]:
182
183
return pipeline_config .model_dump ()
183
184
184
185
def draw (
185
- self , path : str , layout : str = "dot " , hide_unused_outputs : bool = True
186
+ self , path : str , layout : str = "force " , hide_unused_outputs : bool = True
186
187
) -> Any :
187
- G = self .get_pygraphviz_graph (hide_unused_outputs )
188
- G .layout (layout )
189
- G .draw (path )
190
-
191
- def get_pygraphviz_graph (self , hide_unused_outputs : bool = True ) -> pgv .AGraph :
192
- if pgv is None :
188
+ """Draw the pipeline graph using neo4j-viz.
189
+
190
+ Args:
191
+ path (str): Path to save the visualization. If the path ends with .html, it will save an HTML file.
192
+ Otherwise, it will save a PNG image.
193
+ layout (str): Layout algorithm to use. Default is "force".
194
+ hide_unused_outputs (bool): Whether to hide unused outputs. Default is True.
195
+
196
+ Returns:
197
+ Any: The visualization object.
198
+ """
199
+ G = self .get_neo4j_viz_graph (hide_unused_outputs )
200
+ if path .endswith (".html" ):
201
+ # Save as HTML file
202
+ with open (path , "w" ) as f :
203
+ f .write (G .render ()._repr_html_ ())
204
+ else :
205
+ # For other formats, we'll use the render method and save the image
206
+ G .render ()
207
+ # Note: neo4j-viz doesn't support direct saving to image formats
208
+ # If image format is needed, consider using a screenshot or other methods
209
+ with open (path , "w" ) as f :
210
+ f .write (G .render ()._repr_html_ ())
211
+
212
+ def get_neo4j_viz_graph (self , hide_unused_outputs : bool = True ) -> NeoVizGraph :
213
+ """Create a neo4j-viz visualization graph from the pipeline.
214
+
215
+ Args:
216
+ hide_unused_outputs (bool): Whether to hide unused outputs. Default is True.
217
+
218
+ Returns:
219
+ NeoVizGraph: The neo4j-viz visualization graph.
220
+ """
221
+ if not HAS_NEO4J_VIZ :
193
222
raise ImportError (
194
- "Could not import pygraphviz. "
195
- "Follow installation instruction in pygraphviz documentation "
196
- "to get it up and running on your system."
223
+ "Could not import neo4j-viz. "
224
+ "Install it with 'pip install neo4j-viz'."
197
225
)
198
226
self .validate_parameter_mapping ()
199
- G = pgv .AGraph (strict = False , directed = True )
200
- # create a node for each component
227
+
228
+ nodes = []
229
+ relationships = []
230
+ node_ids = {}
231
+ node_counter = 0
232
+
233
+ # Create nodes for each component
201
234
for n , node in self ._nodes .items ():
202
235
comp_inputs = "," .join (
203
236
f"{ i } : { d ['annotation' ]} "
204
237
for i , d in node .component .component_inputs .items ()
205
238
)
206
- G .add_node (
207
- n ,
208
- node_type = "component" ,
209
- shape = "rectangle" ,
210
- label = f"{ node .component .__class__ .__name__ } : { n } ({ comp_inputs } )" ,
239
+ node_ids [n ] = node_counter
240
+ nodes .append (
241
+ Node (
242
+ id = node_counter ,
243
+ caption = f"{ node .component .__class__ .__name__ } : { n } ({ comp_inputs } )" ,
244
+ size = 20 , # Component nodes are larger
245
+ color = "#4C8BF5" # Blue for component nodes
246
+ )
211
247
)
212
- # create a node for each output field and connect them it to its component
248
+ node_counter += 1
249
+
250
+ # Create nodes for each output field
213
251
for o in node .component .component_outputs :
214
252
param_node_name = f"{ n } .{ o } "
215
- G .add_node (param_node_name , label = o , node_type = "output" )
216
- G .add_edge (n , param_node_name )
217
- # then we create the edges between a component output
218
- # and the component it gets added to
253
+ node_ids [param_node_name ] = node_counter
254
+ nodes .append (
255
+ Node (
256
+ id = node_counter ,
257
+ caption = o ,
258
+ size = 10 , # Output nodes are smaller
259
+ color = "#34A853" # Green for output nodes
260
+ )
261
+ )
262
+ # Connect component to its output
263
+ relationships .append (
264
+ Relationship (
265
+ source = node_ids [n ],
266
+ target = node_ids [param_node_name ],
267
+ caption = ""
268
+ )
269
+ )
270
+ node_counter += 1
271
+
272
+ # Create edges between components and their inputs
219
273
for component_name , params in self .param_mapping .items ():
220
274
for param , mapping in params .items ():
221
275
source_component = mapping ["component" ]
@@ -224,13 +278,28 @@ def get_pygraphviz_graph(self, hide_unused_outputs: bool = True) -> pgv.AGraph:
224
278
source_output_node = f"{ source_component } .{ source_param_name } "
225
279
else :
226
280
source_output_node = source_component
227
- G .add_edge (source_output_node , component_name , label = param )
228
- # remove outputs that are not mapped
281
+
282
+ if source_output_node in node_ids and component_name in node_ids :
283
+ relationships .append (
284
+ Relationship (
285
+ source = node_ids [source_output_node ],
286
+ target = node_ids [component_name ],
287
+ caption = param ,
288
+ color = "#EA4335" # Red for parameter connections
289
+ )
290
+ )
291
+
292
+ # Filter unused outputs if requested
229
293
if hide_unused_outputs :
230
- for n in G .nodes ():
231
- if n .attr ["node_type" ] == "output" and G .out_degree (n ) == 0 : # type: ignore
232
- G .remove_node (n )
233
- return G
294
+ used_nodes = set ()
295
+ for rel in relationships :
296
+ used_nodes .add (rel .source )
297
+ used_nodes .add (rel .target )
298
+
299
+ filtered_nodes = [node for node in nodes if node .id in used_nodes ]
300
+ return NeoVizGraph (nodes = filtered_nodes , relationships = relationships )
301
+
302
+ return NeoVizGraph (nodes = nodes , relationships = relationships )
234
303
235
304
def add_component (self , component : Component , name : str ) -> None :
236
305
"""Add a new component. Components are uniquely identified
0 commit comments