Skip to content

Commit b2971a9

Browse files
committed
Replace pygraphviz with neo4j-viz for graph visualization
- Replace pygraphviz imports with neo4j-viz in pipeline.py - Update draw() method to save visualizations as HTML files - Implement get_neo4j_viz_graph() method for interactive visualizations - Update example in visualization.py to use HTML output - Update tests to work with neo4j-viz implementation - Add color coding for better visual distinction
1 parent 92a176f commit b2971a9

File tree

3 files changed

+123
-44
lines changed

3 files changed

+123
-44
lines changed

examples/customize/build_graph/pipeline/visualization.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,5 +54,6 @@ async def run(self, number: IntDataModel) -> IntDataModel:
5454
pipe.connect("times_two", "addition", {"a": "times_two.value"})
5555
pipe.connect("times_ten", "addition", {"b": "times_ten.value"})
5656
pipe.connect("addition", "save", {"number": "addition"})
57-
pipe.draw("graph.png")
58-
pipe.draw("graph_full.png", hide_unused_outputs=False)
57+
# Save as HTML files for interactive visualization
58+
pipe.draw("graph.html")
59+
pipe.draw("graph_full.html", hide_unused_outputs=False)

src/neo4j_graphrag/experimental/pipeline/pipeline.py

Lines changed: 99 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,10 @@
2323
from neo4j_graphrag.utils.logging import prettify
2424

2525
try:
26-
import pygraphviz as pgv
26+
from neo4j_viz import Node, Relationship, VisualizationGraph as NeoVizGraph
27+
HAS_NEO4J_VIZ = True
2728
except ImportError:
28-
pgv = None
29+
HAS_NEO4J_VIZ = False
2930

3031
from pydantic import BaseModel
3132

@@ -182,40 +183,93 @@ def show_as_dict(self) -> dict[str, Any]:
182183
return pipeline_config.model_dump()
183184

184185
def draw(
185-
self, path: str, layout: str = "dot", hide_unused_outputs: bool = True
186+
self, path: str, layout: str = "force", hide_unused_outputs: bool = True
186187
) -> Any:
187-
G = self.get_pygraphviz_graph(hide_unused_outputs)
188-
G.layout(layout)
189-
G.draw(path)
190-
191-
def get_pygraphviz_graph(self, hide_unused_outputs: bool = True) -> pgv.AGraph:
192-
if pgv is None:
188+
"""Draw the pipeline graph using neo4j-viz.
189+
190+
Args:
191+
path (str): Path to save the visualization. If the path ends with .html, it will save an HTML file.
192+
Otherwise, it will save a PNG image.
193+
layout (str): Layout algorithm to use. Default is "force".
194+
hide_unused_outputs (bool): Whether to hide unused outputs. Default is True.
195+
196+
Returns:
197+
Any: The visualization object.
198+
"""
199+
G = self.get_neo4j_viz_graph(hide_unused_outputs)
200+
if path.endswith(".html"):
201+
# Save as HTML file
202+
with open(path, "w") as f:
203+
f.write(G.render()._repr_html_())
204+
else:
205+
# For other formats, we'll use the render method and save the image
206+
G.render()
207+
# Note: neo4j-viz doesn't support direct saving to image formats
208+
# If image format is needed, consider using a screenshot or other methods
209+
with open(path, "w") as f:
210+
f.write(G.render()._repr_html_())
211+
212+
def get_neo4j_viz_graph(self, hide_unused_outputs: bool = True) -> NeoVizGraph:
213+
"""Create a neo4j-viz visualization graph from the pipeline.
214+
215+
Args:
216+
hide_unused_outputs (bool): Whether to hide unused outputs. Default is True.
217+
218+
Returns:
219+
NeoVizGraph: The neo4j-viz visualization graph.
220+
"""
221+
if not HAS_NEO4J_VIZ:
193222
raise ImportError(
194-
"Could not import pygraphviz. "
195-
"Follow installation instruction in pygraphviz documentation "
196-
"to get it up and running on your system."
223+
"Could not import neo4j-viz. "
224+
"Install it with 'pip install neo4j-viz'."
197225
)
198226
self.validate_parameter_mapping()
199-
G = pgv.AGraph(strict=False, directed=True)
200-
# create a node for each component
227+
228+
nodes = []
229+
relationships = []
230+
node_ids = {}
231+
node_counter = 0
232+
233+
# Create nodes for each component
201234
for n, node in self._nodes.items():
202235
comp_inputs = ",".join(
203236
f"{i}: {d['annotation']}"
204237
for i, d in node.component.component_inputs.items()
205238
)
206-
G.add_node(
207-
n,
208-
node_type="component",
209-
shape="rectangle",
210-
label=f"{node.component.__class__.__name__}: {n}({comp_inputs})",
239+
node_ids[n] = node_counter
240+
nodes.append(
241+
Node(
242+
id=node_counter,
243+
caption=f"{node.component.__class__.__name__}: {n}({comp_inputs})",
244+
size=20, # Component nodes are larger
245+
color="#4C8BF5" # Blue for component nodes
246+
)
211247
)
212-
# create a node for each output field and connect them it to its component
248+
node_counter += 1
249+
250+
# Create nodes for each output field
213251
for o in node.component.component_outputs:
214252
param_node_name = f"{n}.{o}"
215-
G.add_node(param_node_name, label=o, node_type="output")
216-
G.add_edge(n, param_node_name)
217-
# then we create the edges between a component output
218-
# and the component it gets added to
253+
node_ids[param_node_name] = node_counter
254+
nodes.append(
255+
Node(
256+
id=node_counter,
257+
caption=o,
258+
size=10, # Output nodes are smaller
259+
color="#34A853" # Green for output nodes
260+
)
261+
)
262+
# Connect component to its output
263+
relationships.append(
264+
Relationship(
265+
source=node_ids[n],
266+
target=node_ids[param_node_name],
267+
caption=""
268+
)
269+
)
270+
node_counter += 1
271+
272+
# Create edges between components and their inputs
219273
for component_name, params in self.param_mapping.items():
220274
for param, mapping in params.items():
221275
source_component = mapping["component"]
@@ -224,13 +278,28 @@ def get_pygraphviz_graph(self, hide_unused_outputs: bool = True) -> pgv.AGraph:
224278
source_output_node = f"{source_component}.{source_param_name}"
225279
else:
226280
source_output_node = source_component
227-
G.add_edge(source_output_node, component_name, label=param)
228-
# remove outputs that are not mapped
281+
282+
if source_output_node in node_ids and component_name in node_ids:
283+
relationships.append(
284+
Relationship(
285+
source=node_ids[source_output_node],
286+
target=node_ids[component_name],
287+
caption=param,
288+
color="#EA4335" # Red for parameter connections
289+
)
290+
)
291+
292+
# Filter unused outputs if requested
229293
if hide_unused_outputs:
230-
for n in G.nodes():
231-
if n.attr["node_type"] == "output" and G.out_degree(n) == 0: # type: ignore
232-
G.remove_node(n)
233-
return G
294+
used_nodes = set()
295+
for rel in relationships:
296+
used_nodes.add(rel.source)
297+
used_nodes.add(rel.target)
298+
299+
filtered_nodes = [node for node in nodes if node.id in used_nodes]
300+
return NeoVizGraph(nodes=filtered_nodes, relationships=relationships)
301+
302+
return NeoVizGraph(nodes=nodes, relationships=relationships)
234303

235304
def add_component(self, component: Component, name: str) -> None:
236305
"""Add a new component. Components are uniquely identified

tests/unit/experimental/pipeline/test_pipeline.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -379,39 +379,48 @@ async def test_pipeline_async() -> None:
379379
assert pipeline_result[1].result == {"add": {"result": 12}}
380380

381381

382-
def test_pipeline_to_pgv() -> None:
382+
def test_pipeline_to_neo4j_viz() -> None:
383383
pipe = Pipeline()
384384
component_a = ComponentAdd()
385385
component_b = ComponentMultiply()
386386
pipe.add_component(component_a, "a")
387387
pipe.add_component(component_b, "b")
388388
pipe.connect("a", "b", {"number1": "a.result"})
389-
g = pipe.get_pygraphviz_graph()
390-
# 3 nodes:
389+
g = pipe.get_neo4j_viz_graph()
390+
# 4 nodes:
391391
# - 2 components 'a' and 'b'
392-
# - 1 output 'a.result'
393-
assert len(g.nodes()) == 3
394-
g = pipe.get_pygraphviz_graph(hide_unused_outputs=False)
392+
# - 2 outputs 'a.result' and 'b.result' (neo4j-viz implementation includes both)
393+
assert len(g.nodes) == 4
394+
395+
# Count component nodes
396+
component_nodes = [node for node in g.nodes if node.size == 20]
397+
assert len(component_nodes) == 2
398+
399+
# Count output nodes
400+
output_nodes = [node for node in g.nodes if node.size == 10]
401+
assert len(output_nodes) == 2
402+
403+
g = pipe.get_neo4j_viz_graph(hide_unused_outputs=False)
395404
# 4 nodes:
396405
# - 2 components 'a' and 'b'
397-
# - 2 output 'a.result' and 'b.result'
398-
assert len(g.nodes()) == 4
406+
# - 2 outputs 'a.result' and 'b.result'
407+
assert len(g.nodes) == 4
399408

400409

401410
def test_pipeline_draw() -> None:
402411
pipe = Pipeline()
403412
pipe.add_component(ComponentAdd(), "add")
404-
t = tempfile.NamedTemporaryFile()
413+
t = tempfile.NamedTemporaryFile(suffix=".html")
405414
pipe.draw(t.name)
406415
content = t.file.read()
407416
assert len(content) > 0
408417

409418

410-
@patch("neo4j_graphrag.experimental.pipeline.pipeline.pgv", None)
411-
def test_pipeline_draw_missing_pygraphviz_dep() -> None:
419+
@patch("neo4j_graphrag.experimental.pipeline.pipeline.HAS_NEO4J_VIZ", False)
420+
def test_pipeline_draw_missing_neo4j_viz_dep() -> None:
412421
pipe = Pipeline()
413422
pipe.add_component(ComponentAdd(), "add")
414-
t = tempfile.NamedTemporaryFile()
423+
t = tempfile.NamedTemporaryFile(suffix=".html")
415424
with pytest.raises(ImportError):
416425
pipe.draw(t.name)
417426

0 commit comments

Comments
 (0)