Skip to content

Commit f31aafc

Browse files
committed
Replace pygraphviz with neo4j-viz for pipeline visualization
This commit replaces pygraphviz with neo4j-viz for pipeline visualization, providing a more interactive HTML-based visualization experience. The changes include: - Updated the Pipeline.draw() method to generate HTML output using neo4j-viz - Added a new get_neo4j_viz_graph() method while maintaining backward compatibility - Updated dependencies in pyproject.toml to use neo4j-viz instead of pygraphviz - Updated documentation and examples to reflect the change from PNG to HTML output - Updated unit tests to work with the new visualization implementation - Added stub file for neo4j-viz to make mypy happy Delete tmp file Update README Make internal method private Remove unnecessary formatting Better error message when trying to use without having it installed Add header to new file
1 parent 85eaa5b commit f31aafc

File tree

8 files changed

+1404
-994
lines changed

8 files changed

+1404
-994
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
### Changed
1414

1515
- Improved log output readability in Retrievers and GraphRAG and added embedded vector to retriever result metadata for debugging.
16+
- Switched from pygraphviz to neo4j-viz
17+
- Renders interactive graph now on HTML instead of PNG
18+
- Removed `get_pygraphviz_graph` method
1619

1720
### Fixed
1821

docs/source/user_guide_pipeline.rst

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -111,25 +111,26 @@ Pipelines can be visualized using the `draw` method:
111111
pipe = Pipeline()
112112
# ... define components and connections
113113
114-
pipe.draw("pipeline.png")
114+
pipe.draw("pipeline.html")
115115
116-
Here is an example pipeline rendering:
116+
Here is an example pipeline rendering as an interactive HTML visualization:
117117

118-
.. image:: images/pipeline_no_unused_outputs.png
119-
:alt: Pipeline visualisation with hidden outputs if unused
118+
.. code:: python
120119
120+
# To view the visualization in a browser
121+
import webbrowser
122+
webbrowser.open("pipeline.html")
121123
122124
By default, output fields which are not mapped to any component are hidden. They
123-
can be added to the canvas by setting `hide_unused_outputs` to `False`:
125+
can be added to the visualization by setting `hide_unused_outputs` to `False`:
124126

125127
.. code:: python
126128
127-
pipe.draw("pipeline.png", hide_unused_outputs=False)
128-
129-
Here is an example of final result:
130-
131-
.. image:: images/pipeline_full.png
132-
:alt: Pipeline visualisation
129+
pipe.draw("pipeline_full.html", hide_unused_outputs=False)
130+
131+
# To view the full visualization in a browser
132+
import webbrowser
133+
webbrowser.open("pipeline_full.html")
133134
134135
135136
************************

examples/customize/build_graph/pipeline/visualization.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,5 +54,5 @@ async def run(self, number: IntDataModel) -> IntDataModel:
5454
pipe.connect("times_two", "addition", {"a": "times_two.value"})
5555
pipe.connect("times_ten", "addition", {"b": "times_ten.value"})
5656
pipe.connect("addition", "save", {"number": "addition"})
57-
pipe.draw("graph.png")
58-
pipe.draw("graph_full.png", hide_unused_outputs=False)
57+
pipe.draw("graph.html")
58+
pipe.draw("graph_full.html", hide_unused_outputs=False)

poetry.lock

Lines changed: 1205 additions & 932 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,7 @@ pyyaml = "^6.0.2"
3838
types-pyyaml = "^6.0.12.20240917"
3939
# optional deps
4040
langchain-text-splitters = {version = "^0.3.0", optional = true }
41-
pygraphviz = [
42-
{version = "^1.13.0", python = ">=3.10,<4.0.0", optional = true},
43-
{version = "^1.0.0", python = "<3.10", optional = true}
44-
]
41+
neo4j-viz = {version = "^0.2.2", optional = true }
4542
weaviate-client = {version = "^4.6.1", optional = true }
4643
pinecone-client = {version = "^4.1.0", optional = true }
4744
google-cloud-aiplatform = {version = "^1.66.0", optional = true }
@@ -68,6 +65,7 @@ sphinx = { version = "^7.2.6", python = "^3.9" }
6865
langchain-openai = {version = "^0.2.2", optional = true }
6966
langchain-huggingface = {version = "^0.1.0", optional = true }
7067
enum-tools = {extras = ["sphinx"], version = "^0.12.0"}
68+
neo4j-viz = "^0.2.2"
7169

7270
[tool.poetry.extras]
7371
weaviate = ["weaviate-client"]
@@ -79,9 +77,9 @@ ollama = ["ollama"]
7977
openai = ["openai"]
8078
mistralai = ["mistralai"]
8179
qdrant = ["qdrant-client"]
82-
kg_creation_tools = ["pygraphviz"]
80+
kg_creation_tools = ["neo4j-viz"]
8381
sentence-transformers = ["sentence-transformers"]
84-
experimental = ["langchain-text-splitters", "pygraphviz", "llama-index"]
82+
experimental = ["langchain-text-splitters", "neo4j-viz", "llama-index"]
8583
examples = ["langchain-openai", "langchain-huggingface"]
8684
nlp = ["spacy"]
8785
fuzzy-matching = ["rapidfuzz"]
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Copyright (c) "Neo4j"
2+
# Neo4j Sweden AB [https://neo4j.com]
3+
# #
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
# #
8+
# https://www.apache.org/licenses/LICENSE-2.0
9+
# #
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
from typing import Any, Dict, List, Optional, Union
17+
18+
class Node:
19+
id: Union[str, int]
20+
caption: Optional[str] = None
21+
size: Optional[float] = None
22+
properties: Optional[Dict[str, Any]] = None
23+
24+
def __init__(
25+
self,
26+
id: Union[str, int],
27+
caption: Optional[str] = None,
28+
size: Optional[float] = None,
29+
properties: Optional[Dict[str, Any]] = None,
30+
**kwargs: Any,
31+
) -> None: ...
32+
33+
class Relationship:
34+
source: Union[str, int]
35+
target: Union[str, int]
36+
caption: Optional[str] = None
37+
properties: Optional[Dict[str, Any]] = None
38+
39+
def __init__(
40+
self,
41+
source: Union[str, int],
42+
target: Union[str, int],
43+
caption: Optional[str] = None,
44+
properties: Optional[Dict[str, Any]] = None,
45+
**kwargs: Any,
46+
) -> None: ...
47+
48+
class VisualizationGraph:
49+
nodes: List[Node]
50+
relationships: List[Relationship]
51+
52+
def __init__(
53+
self, nodes: List[Node], relationships: List[Relationship]
54+
) -> None: ...

src/neo4j_graphrag/experimental/pipeline/pipeline.py

Lines changed: 115 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,17 @@
1818
import warnings
1919
from collections import defaultdict
2020
from timeit import default_timer
21-
from typing import Any, Optional, AsyncGenerator
21+
from typing import Any, Optional, AsyncGenerator, cast
2222
import asyncio
2323

2424
from neo4j_graphrag.utils.logging import prettify
2525

2626
try:
27-
import pygraphviz as pgv
27+
from neo4j_viz import Node, Relationship, VisualizationGraph
28+
29+
neo4j_viz_available = True
2830
except ImportError:
29-
pgv = None
31+
neo4j_viz_available = False
3032

3133
from pydantic import BaseModel
3234

@@ -198,53 +200,132 @@ def show_as_dict(self) -> dict[str, Any]:
198200
def draw(
199201
self, path: str, layout: str = "dot", hide_unused_outputs: bool = True
200202
) -> Any:
201-
G = self.get_pygraphviz_graph(hide_unused_outputs)
202-
G.layout(layout)
203-
G.draw(path)
203+
"""Render the pipeline graph to an HTML file at the specified path"""
204+
G = self._get_neo4j_viz_graph(hide_unused_outputs)
205+
206+
# Write the visualization to an HTML file
207+
with open(path, "w") as f:
208+
f.write(G.render().data)
209+
210+
return G
204211

205-
def get_pygraphviz_graph(self, hide_unused_outputs: bool = True) -> pgv.AGraph:
206-
if pgv is None:
212+
def _get_neo4j_viz_graph(
213+
self, hide_unused_outputs: bool = True
214+
) -> VisualizationGraph:
215+
"""Generate a neo4j-viz visualization of the pipeline graph"""
216+
if not neo4j_viz_available:
207217
raise ImportError(
208-
"Could not import pygraphviz. "
209-
"Follow installation instruction in pygraphviz documentation "
210-
"to get it up and running on your system."
218+
"Could not import neo4j-viz. Install it with 'pip install \"neo4j-graphrag[experimental]\"'"
211219
)
220+
212221
self.validate_parameter_mapping()
213-
G = pgv.AGraph(strict=False, directed=True)
214-
# create a node for each component
215-
for n, node in self._nodes.items():
216-
comp_inputs = ",".join(
222+
223+
nodes = []
224+
relationships = []
225+
node_ids = {} # Map node names to their numeric IDs
226+
next_id = 0
227+
228+
# Create nodes for each component
229+
for n, pipeline_node in self._nodes.items():
230+
comp_inputs = ", ".join(
217231
f"{i}: {d['annotation']}"
218-
for i, d in node.component.component_inputs.items()
232+
for i, d in pipeline_node.component.component_inputs.items()
219233
)
220-
G.add_node(
221-
n,
222-
node_type="component",
223-
shape="rectangle",
224-
label=f"{node.component.__class__.__name__}: {n}({comp_inputs})",
234+
235+
node_ids[n] = next_id
236+
label = f"{pipeline_node.component.__class__.__name__}: {n}({comp_inputs})"
237+
238+
# Create Node with properties parameter
239+
viz_node = Node( # type: ignore
240+
id=next_id,
241+
caption=label,
242+
size=20,
243+
properties={"node_type": "component"},
225244
)
226-
# create a node for each output field and connect them it to its component
227-
for o in node.component.component_outputs:
245+
# Cast the node to Any before adding it to the list
246+
nodes.append(cast(Any, viz_node))
247+
next_id += 1
248+
249+
# Create nodes for each output field
250+
for o in pipeline_node.component.component_outputs:
228251
param_node_name = f"{n}.{o}"
229-
G.add_node(param_node_name, label=o, node_type="output")
230-
G.add_edge(n, param_node_name)
231-
# then we create the edges between a component output
232-
# and the component it gets added to
252+
253+
# Skip if we're hiding unused outputs and it's not used
254+
if hide_unused_outputs:
255+
# Check if this output is used as a source in any parameter mapping
256+
is_used = False
257+
for params in self.param_mapping.values():
258+
for mapping in params.values():
259+
source_component = mapping["component"]
260+
source_param_name = mapping.get("param")
261+
if source_component == n and source_param_name == o:
262+
is_used = True
263+
break
264+
if is_used:
265+
break
266+
267+
if not is_used:
268+
continue
269+
270+
node_ids[param_node_name] = next_id
271+
# Create Node with properties parameter
272+
output_node = Node( # type: ignore
273+
id=next_id,
274+
caption=o,
275+
size=15,
276+
properties={"node_type": "output"},
277+
)
278+
# Cast the node to Any before adding it to the list
279+
nodes.append(cast(Any, output_node))
280+
281+
# Connect component to its output
282+
# Add type ignore comment to suppress mypy errors
283+
rel = Relationship( # type: ignore
284+
source=node_ids[n],
285+
target=node_ids[param_node_name],
286+
properties={"type": "HAS_OUTPUT"},
287+
)
288+
relationships.append(rel)
289+
next_id += 1
290+
291+
# Create edges between components based on parameter mapping
233292
for component_name, params in self.param_mapping.items():
234293
for param, mapping in params.items():
235294
source_component = mapping["component"]
236295
source_param_name = mapping.get("param")
296+
237297
if source_param_name:
238298
source_output_node = f"{source_component}.{source_param_name}"
239299
else:
240300
source_output_node = source_component
241-
G.add_edge(source_output_node, component_name, label=param)
242-
# remove outputs that are not mapped
243-
if hide_unused_outputs:
244-
for n in G.nodes():
245-
if n.attr["node_type"] == "output" and G.out_degree(n) == 0: # type: ignore
246-
G.remove_node(n)
247-
return G
301+
302+
if source_output_node in node_ids and component_name in node_ids:
303+
# Add type ignore comment to suppress mypy errors
304+
rel = Relationship( # type: ignore
305+
source=node_ids[source_output_node],
306+
target=node_ids[component_name],
307+
caption=param,
308+
properties={"type": "CONNECTS_TO"},
309+
)
310+
relationships.append(rel)
311+
312+
# Cast the constructor to Any, then cast the result back to VisualizationGraph
313+
viz_graph = cast(Any, VisualizationGraph)(
314+
nodes=nodes, relationships=relationships
315+
)
316+
# Cast the result back to the expected return type
317+
return cast(VisualizationGraph, viz_graph)
318+
319+
def get_pygraphviz_graph(self, hide_unused_outputs: bool = True) -> Any:
320+
"""Legacy method for backward compatibility.
321+
Uses neo4j-viz instead of pygraphviz.
322+
"""
323+
warnings.warn(
324+
"get_pygraphviz_graph is deprecated, use draw instead",
325+
DeprecationWarning,
326+
stacklevel=2,
327+
)
328+
return self._get_neo4j_viz_graph(hide_unused_outputs)
248329

249330
def add_component(self, component: Component, name: str) -> None:
250331
"""Add a new component. Components are uniquely identified

tests/unit/experimental/pipeline/test_pipeline.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -380,39 +380,39 @@ async def test_pipeline_async() -> None:
380380
assert pipeline_result[1].result == {"add": {"result": 12}}
381381

382382

383-
def test_pipeline_to_pgv() -> None:
383+
def test_pipeline_to_viz() -> None:
384384
pipe = Pipeline()
385385
component_a = ComponentAdd()
386386
component_b = ComponentMultiply()
387387
pipe.add_component(component_a, "a")
388388
pipe.add_component(component_b, "b")
389389
pipe.connect("a", "b", {"number1": "a.result"})
390-
g = pipe.get_pygraphviz_graph()
390+
g = pipe._get_neo4j_viz_graph()
391391
# 3 nodes:
392392
# - 2 components 'a' and 'b'
393393
# - 1 output 'a.result'
394-
assert len(g.nodes()) == 3
395-
g = pipe.get_pygraphviz_graph(hide_unused_outputs=False)
394+
assert len(g.nodes) == 3
395+
g = pipe._get_neo4j_viz_graph(hide_unused_outputs=False)
396396
# 4 nodes:
397397
# - 2 components 'a' and 'b'
398398
# - 2 output 'a.result' and 'b.result'
399-
assert len(g.nodes()) == 4
399+
assert len(g.nodes) == 4
400400

401401

402402
def test_pipeline_draw() -> None:
403403
pipe = Pipeline()
404404
pipe.add_component(ComponentAdd(), "add")
405-
t = tempfile.NamedTemporaryFile()
405+
t = tempfile.NamedTemporaryFile(suffix=".html")
406406
pipe.draw(t.name)
407407
content = t.file.read()
408408
assert len(content) > 0
409409

410410

411-
@patch("neo4j_graphrag.experimental.pipeline.pipeline.pgv", None)
412-
def test_pipeline_draw_missing_pygraphviz_dep() -> None:
411+
@patch("neo4j_graphrag.experimental.pipeline.pipeline.neo4j_viz_available", False)
412+
def test_pipeline_draw_missing_neo4j_viz_dep() -> None:
413413
pipe = Pipeline()
414414
pipe.add_component(ComponentAdd(), "add")
415-
t = tempfile.NamedTemporaryFile()
415+
t = tempfile.NamedTemporaryFile(suffix=".html")
416416
with pytest.raises(ImportError):
417417
pipe.draw(t.name)
418418

0 commit comments

Comments
 (0)