Skip to content

Commit e0b5e86

Browse files
authored
Refactors Neo4j version checking code (#255)
* Refactors Neo4j version checking code into a set of utility functions * Adds the copyright header to new files * Updated base retriever * Fixed hybrid retriever unit tests * Updated Text2Cypher unit tests * Updated vector retriever tests * Added the ability to return the edition of the db in get_version * Updated kg writer
1 parent 3b7ded4 commit e0b5e86

File tree

12 files changed

+426
-256
lines changed

12 files changed

+426
-256
lines changed

src/neo4j_graphrag/experimental/components/kg_writer.py

Lines changed: 6 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@
3434
UPSERT_RELATIONSHIP_QUERY,
3535
UPSERT_RELATIONSHIP_QUERY_VARIABLE_SCOPE_CLAUSE,
3636
)
37+
from neo4j_graphrag.utils.version_utils import (
38+
get_version,
39+
is_version_5_23_or_above,
40+
)
3741

3842
logger = logging.getLogger(__name__)
3943

@@ -116,7 +120,8 @@ def __init__(
116120
self.driver = driver
117121
self.neo4j_database = neo4j_database
118122
self.batch_size = batch_size
119-
self.is_version_5_23_or_above = self._check_if_version_5_23_or_above()
123+
version_tuple, _, _ = get_version(self.driver, self.neo4j_database)
124+
self.is_version_5_23_or_above = is_version_5_23_or_above(version_tuple)
120125

121126
def _db_setup(self) -> None:
122127
# create index on __KGBuilder__.id
@@ -162,29 +167,6 @@ def _upsert_nodes(
162167
database_=self.neo4j_database,
163168
)
164169

165-
def _get_version(self) -> tuple[int, ...]:
166-
records, _, _ = self.driver.execute_query(
167-
"CALL dbms.components()", database_=self.neo4j_database
168-
)
169-
version = records[0]["versions"][0]
170-
# Drop everything after the '-' first
171-
version_main, *_ = version.split("-")
172-
# Convert each number between '.' into int
173-
version_tuple = tuple(map(int, version_main.split(".")))
174-
# If no patch version, consider it's 0
175-
if len(version_tuple) < 3:
176-
version_tuple = (*version_tuple, 0)
177-
return version_tuple
178-
179-
def _check_if_version_5_23_or_above(self) -> bool:
180-
"""
181-
Check if the connected Neo4j database version supports the required features.
182-
183-
Sets a flag if the connected Neo4j version is 5.23 or above.
184-
"""
185-
version_tuple = self._get_version()
186-
return version_tuple >= (5, 23, 0)
187-
188170
def _upsert_relationships(self, rels: list[Neo4jRelationship]) -> None:
189171
"""Upserts a single relationship into the Neo4j database.
190172

src/neo4j_graphrag/retrievers/base.py

Lines changed: 14 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,12 @@
2424

2525
from neo4j_graphrag.exceptions import Neo4jVersionError
2626
from neo4j_graphrag.types import RawSearchResult, RetrieverResult, RetrieverResultItem
27+
from neo4j_graphrag.utils.version_utils import (
28+
get_version,
29+
has_metadata_filtering_support,
30+
has_vector_index_support,
31+
is_version_5_23_or_above,
32+
)
2733

2834
T = ParamSpec("T")
2935
P = TypeVar("P")
@@ -85,53 +91,14 @@ def __init__(self, driver: neo4j.Driver, neo4j_database: Optional[str] = None):
8591
self.driver = driver
8692
self.neo4j_database = neo4j_database
8793
if self.VERIFY_NEO4J_VERSION:
88-
self._verify_version()
89-
90-
def _get_version(self) -> tuple[tuple[int, ...], bool]:
91-
records, _, _ = self.driver.execute_query(
92-
"CALL dbms.components()",
93-
database_=self.neo4j_database,
94-
routing_=neo4j.RoutingControl.READ,
95-
)
96-
version = records[0]["versions"][0]
97-
# drop everything after the '-' first
98-
version_main, *_ = version.split("-")
99-
# convert each number between '.' into int
100-
version_tuple = tuple(map(int, version_main.split(".")))
101-
# if no patch version, consider it's 0
102-
if len(version_tuple) < 3:
103-
version_tuple = (*version_tuple, 0)
104-
return version_tuple, "aura" in version
105-
106-
def _check_if_version_5_23_or_above(self, version_tuple: tuple[int, ...]) -> bool:
107-
"""
108-
Check if the connected Neo4j database version supports the required features.
109-
110-
Sets a flag if the connected Neo4j version is 5.23 or above.
111-
"""
112-
return version_tuple >= (5, 23, 0)
113-
114-
def _verify_version(self) -> None:
115-
"""
116-
Check if the connected Neo4j database version supports vector indexing.
117-
118-
Queries the Neo4j database to retrieve its version and compares it
119-
against a target version (5.18.1) that is known to support vector
120-
indexing. Raises a Neo4jMinVersionError if the connected Neo4j version is
121-
not supported.
122-
"""
123-
version_tuple, is_aura = self._get_version()
124-
self.neo4j_version_is_5_23_or_above = self._check_if_version_5_23_or_above(
125-
version_tuple
126-
)
127-
128-
if is_aura:
129-
target_version = (5, 18, 0)
130-
else:
131-
target_version = (5, 18, 1)
132-
133-
if version_tuple < target_version:
134-
raise Neo4jVersionError()
94+
version_tuple, is_aura, _ = get_version(self.driver, self.neo4j_database)
95+
self.neo4j_version_is_5_23_or_above = is_version_5_23_or_above(
96+
version_tuple
97+
)
98+
if not has_vector_index_support(
99+
version_tuple
100+
) or not has_metadata_filtering_support(version_tuple, is_aura):
101+
raise Neo4jVersionError()
135102

136103
def _fetch_index_infos(self, vector_index_name: str) -> None:
137104
"""Fetch the node label and embedding property from the index definition
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
# Copyright (c) "Neo4j"
2+
# Neo4j Sweden AB [https://neo4j.com]
3+
# #
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
# #
8+
# https://www.apache.org/licenses/LICENSE-2.0
9+
# #
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
from typing import Optional
16+
17+
import neo4j
18+
19+
20+
def get_version(
21+
driver: neo4j.Driver, database: Optional[str] = None
22+
) -> tuple[tuple[int, ...], bool, bool]:
23+
"""
24+
Retrieves the Neo4j database version and checks whether it is running on the Aura platform and whether it is the enterprise edition.
25+
26+
Args:
27+
driver (neo4j.Driver): Neo4j Python driver instance to execute the query.
28+
database (str, optional): The name of the Neo4j database to query. Defaults to None.
29+
30+
Returns:
31+
tuple[tuple[int, ...], bool, bool]:
32+
- A tuple of integers representing the database version (major, minor, patch) or
33+
(year, month, patch) for later versions.
34+
- A boolean indicating whether the database is hosted on the Aura platform.
35+
- A boolean indicating whether the database is running the enterprise edition.
36+
"""
37+
records, _, _ = driver.execute_query(
38+
"CALL dbms.components()",
39+
database_=database,
40+
routing_=neo4j.RoutingControl.READ,
41+
)
42+
version = records[0]["versions"][0]
43+
edition = records[0]["edition"]
44+
# drop everything after the '-' first
45+
version_main, *_ = version.split("-")
46+
# convert each number between '.' into int
47+
version_tuple = tuple(map(int, version_main.split(".")))
48+
# if no patch version, consider it's 0
49+
if len(version_tuple) < 3:
50+
version_tuple = (*version_tuple, 0)
51+
return version_tuple, "aura" in version, edition == "enterprise"
52+
53+
54+
def is_version_5_23_or_above(version_tuple: tuple[int, ...]) -> bool:
55+
"""
56+
Determines if the Neo4j database version is 5.23 or above.
57+
58+
Args:
59+
version_tuple (tuple[int, ...]): A tuple of integers representing the database version
60+
(major, minor, patch) or (year, month, patch) for later versions.
61+
62+
Returns:
63+
bool: True if the version is 5.23.0 or above, False otherwise.
64+
"""
65+
return version_tuple >= (5, 23, 0)
66+
67+
68+
def has_vector_index_support(version_tuple: tuple[int, ...]) -> bool:
69+
"""
70+
Checks if a Neo4j database supports vector indexing based on its version.
71+
72+
Args:
73+
version_tuple (tuple[int, ...]): A tuple of integers representing the database version (major, minor, patch) or
74+
(year, month, patch) for later versions.
75+
76+
Returns:
77+
bool: True if the connected Neo4j database version supports vector indexing, False otherwise.
78+
"""
79+
return version_tuple >= (5, 11, 0)
80+
81+
82+
def has_metadata_filtering_support(
83+
version_tuple: tuple[int, ...], is_aura: bool
84+
) -> bool:
85+
"""
86+
Checks if a Neo4j database supports vector index metadata filtering based on its version and platform.
87+
88+
Args:
89+
version_tuple (tuple[int, ...]): A tuple of integers representing the database version (major, minor, patch) or
90+
(year, month, patch) for later versions.
91+
is_aura (bool): A boolean indicating whether the database is hosted on the Aura platform.
92+
93+
Returns:
94+
bool: True if the connected Neo4j database version supports vector index metadata filtering, False otherwise.
95+
"""
96+
if is_aura:
97+
target_version = (5, 18, 0)
98+
else:
99+
target_version = (5, 18, 1)
100+
101+
return version_tuple >= target_version

tests/unit/conftest.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -51,37 +51,35 @@ def retriever_mock() -> MagicMock:
5151

5252

5353
@pytest.fixture(scope="function")
54-
@patch("neo4j_graphrag.retrievers.VectorRetriever._verify_version")
55-
def vector_retriever(
56-
_verify_version_mock: MagicMock, driver: MagicMock
57-
) -> VectorRetriever:
54+
@patch("neo4j_graphrag.retrievers.base.get_version")
55+
def vector_retriever(mock_get_version: MagicMock, driver: MagicMock) -> VectorRetriever:
56+
mock_get_version.return_value = ((5, 23, 0), False, False)
5857
return VectorRetriever(driver, "my-index")
5958

6059

6160
@pytest.fixture(scope="function")
62-
@patch("neo4j_graphrag.retrievers.VectorCypherRetriever._verify_version")
61+
@patch("neo4j_graphrag.retrievers.base.get_version")
6362
def vector_cypher_retriever(
64-
_verify_version_mock: MagicMock, driver: MagicMock
63+
mock_get_version: MagicMock, driver: MagicMock
6564
) -> VectorCypherRetriever:
66-
retrieval_query = """
67-
RETURN node.id AS node_id, node.text AS text, score
68-
"""
65+
mock_get_version.return_value = ((5, 23, 0), False, False)
66+
retrieval_query = "RETURN node.id AS node_id, node.text AS text, score"
6967
return VectorCypherRetriever(driver, "my-index", retrieval_query)
7068

7169

7270
@pytest.fixture(scope="function")
73-
@patch("neo4j_graphrag.retrievers.HybridRetriever._verify_version")
74-
def hybrid_retriever(
75-
_verify_version_mock: MagicMock, driver: MagicMock
76-
) -> HybridRetriever:
71+
@patch("neo4j_graphrag.retrievers.base.get_version")
72+
def hybrid_retriever(mock_get_version: MagicMock, driver: MagicMock) -> HybridRetriever:
73+
mock_get_version.return_value = ((5, 23, 0), False, False)
7774
return HybridRetriever(driver, "my-index", "my-fulltext-index")
7875

7976

8077
@pytest.fixture(scope="function")
81-
@patch("neo4j_graphrag.retrievers.Text2CypherRetriever._verify_version")
78+
@patch("neo4j_graphrag.retrievers.base.get_version")
8279
def t2c_retriever(
83-
_verify_version_mock: MagicMock, driver: MagicMock, llm: MagicMock
80+
mock_get_version: MagicMock, driver: MagicMock, llm: MagicMock
8481
) -> Text2CypherRetriever:
82+
mock_get_version.return_value = ((5, 23, 0), False, False)
8583
return Text2CypherRetriever(driver, llm)
8684

8785

0 commit comments

Comments
 (0)