Skip to content

Commit 9dbe345

Browse files
authored
Make Neo4j version check work with trailing "-XXX" (#162)
* Fix version verification * CI * Ruff
1 parent 53bb4f1 commit 9dbe345

File tree

2 files changed

+51
-16
lines changed

2 files changed

+51
-16
lines changed

src/neo4j_graphrag/retrievers/base.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,20 @@ def __init__(self, driver: neo4j.Driver, neo4j_database: Optional[str] = None):
8787
if self.VERIFY_NEO4J_VERSION:
8888
self._verify_version()
8989

90+
def _get_version(self) -> tuple[tuple[int, ...], bool]:
91+
records, _, _ = self.driver.execute_query(
92+
"CALL dbms.components()", database_=self.neo4j_database
93+
)
94+
version = records[0]["versions"][0]
95+
# drop everything after the '-' first
96+
version_main, *_ = version.split("-")
97+
# convert each number between '.' into int
98+
version_tuple = tuple(map(int, version_main.split(".")))
99+
# if no patch version, consider it's 0
100+
if len(version_tuple) < 3:
101+
version_tuple = (*version_tuple, 0)
102+
return version_tuple, "aura" in version
103+
90104
def _verify_version(self) -> None:
91105
"""
92106
Check if the connected Neo4j database version supports vector indexing.
@@ -96,19 +110,11 @@ def _verify_version(self) -> None:
96110
indexing. Raises a Neo4jMinVersionError if the connected Neo4j version is
97111
not supported.
98112
"""
99-
records, _, _ = self.driver.execute_query(
100-
"CALL dbms.components()", database_=self.neo4j_database
101-
)
102-
version = records[0]["versions"][0]
113+
version_tuple, is_aura = self._get_version()
103114

104-
if "aura" in version:
105-
version_tuple = (
106-
*tuple(map(int, version.split("-")[0].split("."))),
107-
0,
108-
)
115+
if is_aura:
109116
target_version = (5, 18, 0)
110117
else:
111-
version_tuple = tuple(map(int, version.split(".")))
112118
target_version = (5, 18, 1)
113119

114120
if version_tuple < target_version:

tests/unit/retrievers/test_base.py

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,25 +24,54 @@
2424
from neo4j_graphrag.types import RawSearchResult, RetrieverResult
2525

2626

27+
@pytest.mark.parametrize(
28+
"db_version,expected_version",
29+
[
30+
(["5.18-aura"], ((5, 18, 0), True)),
31+
(["5.3-aura"], ((5, 3, 0), True)),
32+
(["5.19.0"], ((5, 19, 0), False)),
33+
(["4.3.5"], ((4, 3, 5), False)),
34+
(["5.23.0-6698"], ((5, 23, 0), False)),
35+
],
36+
)
37+
def test_retriever_get_version(
38+
driver: MagicMock,
39+
db_version: list[str],
40+
expected_version: tuple[tuple[int, ...], bool],
41+
) -> None:
42+
class MockRetriever(Retriever):
43+
VERIFY_NEO4J_VERSION = False
44+
45+
def get_search_results(self, *args: Any, **kwargs: Any) -> RawSearchResult:
46+
return RawSearchResult(records=[])
47+
48+
driver.execute_query.return_value = [[{"versions": db_version}], None, None]
49+
retriever = MockRetriever(driver)
50+
assert retriever._get_version() == expected_version
51+
52+
2753
@pytest.mark.parametrize(
2854
"db_version,expected_exception",
2955
[
30-
(["5.18-aura"], None),
31-
(["5.3-aura"], Neo4jVersionError),
32-
(["5.19.0"], None),
33-
(["4.3.5"], Neo4jVersionError),
56+
(((5, 18, 0), True), None),
57+
(((5, 3, 0), True), Neo4jVersionError),
58+
(((5, 19, 0), False), None),
59+
(((4, 3, 5), False), Neo4jVersionError),
60+
(((5, 23, 0), False), None),
3461
],
3562
)
63+
@patch("neo4j_graphrag.retrievers.base.Retriever._get_version")
3664
def test_retriever_version_support(
65+
mock_get_version: MagicMock,
3766
driver: MagicMock,
38-
db_version: list[str],
67+
db_version: tuple[tuple[int, ...], bool],
3968
expected_exception: Union[type[ValueError], None],
4069
) -> None:
4170
class MockRetriever(Retriever):
4271
def get_search_results(self, *args: Any, **kwargs: Any) -> RawSearchResult:
4372
return RawSearchResult(records=[])
4473

45-
driver.execute_query.return_value = [[{"versions": db_version}], None, None]
74+
mock_get_version.return_value = db_version
4675
if expected_exception:
4776
with pytest.raises(expected_exception):
4877
MockRetriever(driver=driver)

0 commit comments

Comments
 (0)