Skip to content

Commit d0528f4

Browse files
authored
Add check to not use deprecated Cypher syntax when Neo4j version is >= 5.23.0 (#183)
* Add check to not use deprecated Cypher syntax when Neo4j version is >= 5.23.0
* Update CHANGELOG
* Add variable scope query in Hybrid Retriever based on neo4j version
* Include E2E test to test for deprecation warning from deprecated Cypher subquery syntax
* Resolve mypy errors
* Add neo4j:latest to pr and scheduled E2E tests
1 parent d0fe5ea commit d0528f4

File tree

15 files changed

+500
-66
lines changed

15 files changed

+500
-66
lines changed

.github/workflows/pr-e2e-tests.yaml

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,8 @@ jobs:
1616
strategy:
1717
matrix:
1818
python-version: ['3.9', '3.12']
19-
neo4j-version:
20-
- 5
21-
neo4j-edition:
22-
- enterprise
19+
neo4j-tag:
20+
- 'latest'
2321
services:
2422
t2v-transformers:
2523
image: cr.weaviate.io/semitechnologies/transformers-inference:sentence-transformers-all-MiniLM-L6-v2-onnx
@@ -37,7 +35,7 @@ jobs:
3735
- 8080:8080
3836
- 50051:50051
3937
neo4j:
40-
image: neo4j:${{ matrix.neo4j-version }}-${{ matrix.neo4j-edition }}
38+
image: neo4j:${{ matrix.neo4j-tag }}
4139
env:
4240
NEO4J_AUTH: neo4j/password
4341
NEO4J_ACCEPT_LICENSE_AGREEMENT: 'eval'
@@ -93,7 +91,7 @@ jobs:
9391
- name: Run tests
9492
shell: bash
9593
run: |
96-
if [[ "${{ matrix.neo4j-edition }}" == "community" ]]; then
94+
if [[ "${{ matrix.neo4j-tag }}" == "latest" || "${{ matrix.neo4j-tag }}" == *-community ]]; then
9795
poetry run pytest -m 'not enterprise_only' ./tests/e2e
9896
else
9997
poetry run pytest ./tests/e2e

.github/workflows/scheduled-e2e-tests.yaml

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,10 @@ jobs:
1313
strategy:
1414
matrix:
1515
python-version: ['3.9', '3.10', '3.11', '3.12']
16-
neo4j-version:
17-
- 5
18-
neo4j-edition:
19-
- community
20-
- enterprise
16+
neo4j-tag:
17+
- '5-community'
18+
- '5-enterprise'
19+
- 'latest'
2120
services:
2221
t2v-transformers:
2322
image: cr.weaviate.io/semitechnologies/transformers-inference:sentence-transformers-all-MiniLM-L6-v2-onnx
@@ -41,7 +40,7 @@ jobs:
4140
username: ${{ secrets.DOCKERHUB_USERNAME }}
4241
password: ${{ secrets.DOCKERHUB_TOKEN }}
4342
neo4j:
44-
image: neo4j:${{ matrix.neo4j-version }}-${{ matrix.neo4j-edition }}
43+
image: neo4j:${{ matrix.neo4j-tag }}
4544
env:
4645
NEO4J_AUTH: neo4j/password
4746
NEO4J_ACCEPT_LICENSE_AGREEMENT: 'eval'
@@ -100,7 +99,7 @@ jobs:
10099
- name: Run tests
101100
shell: bash
102101
run: |
103-
if [[ "${{ matrix.neo4j-edition }}" == "community" ]]; then
102+
if [[ "${{ matrix.neo4j-tag }}" == "latest" || "${{ matrix.neo4j-tag }}" == *-community ]]; then
104103
poetry run pytest -m 'not enterprise_only' ./tests/e2e
105104
else
106105
poetry run pytest ./tests/e2e

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
### Added
66
- Made `relations` and `potential_schema` optional in `SchemaBuilder`.
7+
- Added a check to prevent the use of deprecated Cypher syntax for Neo4j versions 5.23.0 and above.
78

89
## 1.1.0
910

src/neo4j_graphrag/experimental/components/kg_writer.py

Lines changed: 54 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,12 @@
3333
Neo4jRelationship,
3434
)
3535
from neo4j_graphrag.experimental.pipeline.component import Component, DataModel
36-
from neo4j_graphrag.neo4j_queries import UPSERT_NODE_QUERY, UPSERT_RELATIONSHIP_QUERY
36+
from neo4j_graphrag.neo4j_queries import (
37+
UPSERT_NODE_QUERY,
38+
UPSERT_NODE_QUERY_VARIABLE_SCOPE_CLAUSE,
39+
UPSERT_RELATIONSHIP_QUERY,
40+
UPSERT_RELATIONSHIP_QUERY_VARIABLE_SCOPE_CLAUSE,
41+
)
3742

3843
logger = logging.getLogger(__name__)
3944

@@ -113,6 +118,7 @@ def __init__(
113118
self.neo4j_database = neo4j_database
114119
self.batch_size = batch_size
115120
self.max_concurrency = max_concurrency
121+
self.is_version_5_23_or_above = self._check_if_version_5_23_or_above()
116122

117123
def _db_setup(self) -> None:
118124
# create index on __Entity__.id
@@ -147,7 +153,12 @@ def _upsert_nodes(self, nodes: list[Neo4jNode]) -> None:
147153
nodes (list[Neo4jNode]): The nodes batch to upsert into the database.
148154
"""
149155
parameters = {"rows": self._nodes_to_rows(nodes)}
150-
self.driver.execute_query(UPSERT_NODE_QUERY, parameters_=parameters)
156+
if self.is_version_5_23_or_above:
157+
self.driver.execute_query(
158+
UPSERT_NODE_QUERY_VARIABLE_SCOPE_CLAUSE, parameters_=parameters
159+
)
160+
else:
161+
self.driver.execute_query(UPSERT_NODE_QUERY, parameters_=parameters)
151162

152163
async def _async_upsert_nodes(
153164
self,
@@ -161,7 +172,32 @@ async def _async_upsert_nodes(
161172
"""
162173
async with sem:
163174
parameters = {"rows": self._nodes_to_rows(nodes)}
164-
await self.driver.execute_query(UPSERT_NODE_QUERY, parameters_=parameters)
175+
await self.driver.execute_query(
176+
UPSERT_NODE_QUERY_VARIABLE_SCOPE_CLAUSE, parameters_=parameters
177+
)
178+
179+
def _get_version(self) -> tuple[int, ...]:
180+
records, _, _ = self.driver.execute_query(
181+
"CALL dbms.components()", database_=self.neo4j_database
182+
)
183+
version = records[0]["versions"][0]
184+
# Drop everything after the '-' first
185+
version_main, *_ = version.split("-")
186+
# Convert each number between '.' into int
187+
version_tuple = tuple(map(int, version_main.split(".")))
188+
# If no patch version, consider it's 0
189+
if len(version_tuple) < 3:
190+
version_tuple = (*version_tuple, 0)
191+
return version_tuple
192+
193+
def _check_if_version_5_23_or_above(self) -> bool:
194+
"""
195+
Check if the connected Neo4j database version supports the required features.
196+
197+
Sets a flag if the connected Neo4j version is 5.23 or above.
198+
"""
199+
version_tuple = self._get_version()
200+
return version_tuple >= (5, 23, 0)
165201

166202
def _upsert_relationships(self, rels: list[Neo4jRelationship]) -> None:
167203
"""Upserts a single relationship into the Neo4j database.
@@ -170,7 +206,12 @@ def _upsert_relationships(self, rels: list[Neo4jRelationship]) -> None:
170206
rels (list[Neo4jRelationship]): The relationships batch to upsert into the database.
171207
"""
172208
parameters = {"rows": [rel.model_dump() for rel in rels]}
173-
self.driver.execute_query(UPSERT_RELATIONSHIP_QUERY, parameters_=parameters)
209+
if self.is_version_5_23_or_above:
210+
self.driver.execute_query(
211+
UPSERT_RELATIONSHIP_QUERY_VARIABLE_SCOPE_CLAUSE, parameters_=parameters
212+
)
213+
else:
214+
self.driver.execute_query(UPSERT_RELATIONSHIP_QUERY, parameters_=parameters)
174215

175216
async def _async_upsert_relationships(
176217
self, rels: list[Neo4jRelationship], sem: asyncio.Semaphore
@@ -182,9 +223,15 @@ async def _async_upsert_relationships(
182223
"""
183224
async with sem:
184225
parameters = {"rows": [rel.model_dump() for rel in rels]}
185-
await self.driver.execute_query(
186-
UPSERT_RELATIONSHIP_QUERY, parameters_=parameters
187-
)
226+
if self.is_version_5_23_or_above:
227+
await self.driver.execute_query(
228+
UPSERT_RELATIONSHIP_QUERY_VARIABLE_SCOPE_CLAUSE,
229+
parameters_=parameters,
230+
)
231+
else:
232+
await self.driver.execute_query(
233+
UPSERT_RELATIONSHIP_QUERY, parameters_=parameters
234+
)
188235

189236
@validate_call
190237
async def run(self, graph: Neo4jGraph) -> KGWriterModel:
@@ -193,12 +240,6 @@ async def run(self, graph: Neo4jGraph) -> KGWriterModel:
193240
Args:
194241
graph (Neo4jGraph): The knowledge graph to upsert into the database.
195242
"""
196-
# we disable the notification logger to get rid of the deprecation
197-
# warning about Cypher subqueries. Once the queries are updated
198-
# for Neo4j 5.23, we can remove this line and the 'finally' block
199-
notification_logger = logging.getLogger("neo4j.notifications")
200-
notification_level = notification_logger.level
201-
notification_logger.setLevel(logging.ERROR)
202243
try:
203244
if inspect.iscoroutinefunction(self.driver.execute_query):
204245
await self._async_db_setup()
@@ -233,5 +274,3 @@ async def run(self, graph: Neo4jGraph) -> KGWriterModel:
233274
except neo4j.exceptions.ClientError as e:
234275
logger.exception(e)
235276
return KGWriterModel(status="FAILURE", metadata={"error": str(e)})
236-
finally:
237-
notification_logger.setLevel(notification_level)

src/neo4j_graphrag/neo4j_queries.py

Lines changed: 58 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,20 @@
5555
"RETURN elementId(n)"
5656
)
5757

58+
UPSERT_NODE_QUERY_VARIABLE_SCOPE_CLAUSE = (
59+
"UNWIND $rows AS row "
60+
"CREATE (n:__KGBuilder__ {id: row.id}) "
61+
"SET n += row.properties "
62+
"WITH n, row CALL apoc.create.addLabels(n, row.labels) YIELD node "
63+
"WITH node as n, row CALL (n, row) { "
64+
"WITH n, row WITH n, row WHERE row.embedding_properties IS NOT NULL "
65+
"UNWIND keys(row.embedding_properties) as emb "
66+
"CALL db.create.setNodeVectorProperty(n, emb, row.embedding_properties[emb]) "
67+
"RETURN count(*) as nbEmb "
68+
"} "
69+
"RETURN elementId(n)"
70+
)
71+
5872
UPSERT_RELATIONSHIP_QUERY = (
5973
"UNWIND $rows as row "
6074
"MATCH (start:__KGBuilder__ {id: row.start_node_id}) "
@@ -69,6 +83,21 @@
6983
"RETURN elementId(rel)"
7084
)
7185

86+
87+
UPSERT_RELATIONSHIP_QUERY_VARIABLE_SCOPE_CLAUSE = (
88+
"UNWIND $rows as row "
89+
"MATCH (start:__KGBuilder__ {id: row.start_node_id}) "
90+
"MATCH (end:__KGBuilder__ {id: row.end_node_id}) "
91+
"WITH start, end, row "
92+
"CALL apoc.merge.relationship(start, row.type, {}, row.properties, end, row.properties) YIELD rel "
93+
"WITH rel, row CALL (rel, row) { "
94+
"WITH rel, row WITH rel, row WHERE row.embedding_properties IS NOT NULL "
95+
"UNWIND keys(row.embedding_properties) as emb "
96+
"CALL db.create.setRelationshipVectorProperty(rel, emb, row.embedding_properties[emb]) "
97+
"} "
98+
"RETURN elementId(rel)"
99+
)
100+
72101
UPSERT_VECTOR_ON_NODE_QUERY = (
73102
"MATCH (n) "
74103
"WHERE elementId(n) = $id "
@@ -86,19 +115,33 @@
86115
)
87116

88117

89-
def _get_hybrid_query() -> str:
90-
return (
91-
f"CALL {{ {VECTOR_INDEX_QUERY} "
92-
f"WITH collect({{node:node, score:score}}) AS nodes, max(score) AS vector_index_max_score "
93-
f"UNWIND nodes AS n "
94-
f"RETURN n.node AS node, (n.score / vector_index_max_score) AS score "
95-
f"UNION "
96-
f"{FULL_TEXT_SEARCH_QUERY} "
97-
f"WITH collect({{node:node, score:score}}) AS nodes, max(score) AS ft_index_max_score "
98-
f"UNWIND nodes AS n "
99-
f"RETURN n.node AS node, (n.score / ft_index_max_score) AS score }} "
100-
f"WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k"
101-
)
118+
def _get_hybrid_query(neo4j_version_is_5_23_or_above: bool) -> str:
119+
if neo4j_version_is_5_23_or_above:
120+
return (
121+
f"CALL () {{ {VECTOR_INDEX_QUERY} "
122+
f"WITH collect({{node:node, score:score}}) AS nodes, max(score) AS vector_index_max_score "
123+
f"UNWIND nodes AS n "
124+
f"RETURN n.node AS node, (n.score / vector_index_max_score) AS score "
125+
f"UNION "
126+
f"{FULL_TEXT_SEARCH_QUERY} "
127+
f"WITH collect({{node:node, score:score}}) AS nodes, max(score) AS ft_index_max_score "
128+
f"UNWIND nodes AS n "
129+
f"RETURN n.node AS node, (n.score / ft_index_max_score) AS score }} "
130+
f"WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k"
131+
)
132+
else:
133+
return (
134+
f"CALL {{ {VECTOR_INDEX_QUERY} "
135+
f"WITH collect({{node:node, score:score}}) AS nodes, max(score) AS vector_index_max_score "
136+
f"UNWIND nodes AS n "
137+
f"RETURN n.node AS node, (n.score / vector_index_max_score) AS score "
138+
f"UNION "
139+
f"{FULL_TEXT_SEARCH_QUERY} "
140+
f"WITH collect({{node:node, score:score}}) AS nodes, max(score) AS ft_index_max_score "
141+
f"UNWIND nodes AS n "
142+
f"RETURN n.node AS node, (n.score / ft_index_max_score) AS score }} "
143+
f"WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k"
144+
)
102145

103146

104147
def _get_filtered_vector_query(
@@ -139,6 +182,7 @@ def get_search_query(
139182
embedding_node_property: Optional[str] = None,
140183
embedding_dimension: Optional[int] = None,
141184
filters: Optional[dict[str, Any]] = None,
185+
neo4j_version_is_5_23_or_above: bool = False,
142186
) -> tuple[str, dict[str, Any]]:
143187
"""Build the search query, including pre-filtering if needed, and return clause.
144188
@@ -160,7 +204,7 @@ def get_search_query(
160204
if search_type == SearchType.HYBRID:
161205
if filters:
162206
raise Exception("Filters are not supported with Hybrid Search")
163-
query = _get_hybrid_query()
207+
query = _get_hybrid_query(neo4j_version_is_5_23_or_above)
164208
params: dict[str, Any] = {}
165209
elif search_type == SearchType.VECTOR:
166210
if filters:

src/neo4j_graphrag/retrievers/base.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,14 @@ def _get_version(self) -> tuple[tuple[int, ...], bool]:
101101
version_tuple = (*version_tuple, 0)
102102
return version_tuple, "aura" in version
103103

104+
def _check_if_version_5_23_or_above(self, version_tuple: tuple[int, ...]) -> bool:
105+
"""
106+
Check if the connected Neo4j database version supports the required features.
107+
108+
Sets a flag if the connected Neo4j version is 5.23 or above.
109+
"""
110+
return version_tuple >= (5, 23, 0)
111+
104112
def _verify_version(self) -> None:
105113
"""
106114
Check if the connected Neo4j database version supports vector indexing.
@@ -111,6 +119,9 @@ def _verify_version(self) -> None:
111119
not supported.
112120
"""
113121
version_tuple, is_aura = self._get_version()
122+
self.neo4j_version_is_5_23_or_above = self._check_if_version_5_23_or_above(
123+
version_tuple
124+
)
114125

115126
if is_aura:
116127
target_version = (5, 18, 0)

src/neo4j_graphrag/retrievers/external/pinecone/types.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,7 @@
1717
from typing import Any, Callable, Optional, Union
1818

1919
import neo4j
20-
21-
2220
from pinecone import Pinecone
23-
2421
from pydantic import (
2522
BaseModel,
2623
ConfigDict,

src/neo4j_graphrag/retrievers/external/weaviate/weaviate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from __future__ import annotations
1616

1717
import logging
18-
from typing import Any, Callable, Optional, TYPE_CHECKING
18+
from typing import TYPE_CHECKING, Any, Callable, Optional
1919

2020
import neo4j
2121
import weaviate.classes as wvc

src/neo4j_graphrag/retrievers/hybrid.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,11 @@ def get_search_results(
184184
query_vector = self.embedder.embed_query(query_text)
185185
parameters["query_vector"] = query_vector
186186

187-
search_query, _ = get_search_query(SearchType.HYBRID, self.return_properties)
187+
search_query, _ = get_search_query(
188+
SearchType.HYBRID,
189+
self.return_properties,
190+
neo4j_version_is_5_23_or_above=self.neo4j_version_is_5_23_or_above,
191+
)
188192

189193
logger.debug("HybridRetriever Cypher parameters: %s", parameters)
190194
logger.debug("HybridRetriever Cypher query: %s", search_query)
@@ -336,7 +340,9 @@ def get_search_results(
336340
del parameters["query_params"]
337341

338342
search_query, _ = get_search_query(
339-
SearchType.HYBRID, retrieval_query=self.retrieval_query
343+
SearchType.HYBRID,
344+
retrieval_query=self.retrieval_query,
345+
neo4j_version_is_5_23_or_above=self.neo4j_version_is_5_23_or_above,
340346
)
341347

342348
logger.debug("HybridCypherRetriever Cypher parameters: %s", parameters)

tests/e2e/docker-compose.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ services:
2626
environment:
2727
ENABLE_CUDA: "0"
2828
neo4j:
29-
image: neo4j:5-enterprise
29+
image: neo4j:5.24-enterprise
3030
ports:
3131
- 7687:7687
3232
- 7474:7474

0 commit comments

Comments
 (0)