Skip to content

Commit 02ad6f8

Browse files
committed
Override neo4j user agent when driver is injected
1 parent d7d6674 commit 02ad6f8

File tree

7 files changed

+38
-5
lines changed

7 files changed

+38
-5
lines changed

examples/customize/build_graph/components/writers/custom_writer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,12 @@
66
from neo4j_graphrag.experimental.components.kg_writer import KGWriter, KGWriterModel
77
from neo4j_graphrag.experimental.components.types import LexicalGraphConfig, Neo4jGraph
88
from pydantic import validate_call
9+
from neo4j_graphrag.utils import driver_config
910

1011

1112
class MyWriter(KGWriter):
1213
def __init__(self, driver: neo4j.Driver) -> None:
13-
self.driver = driver
14+
self.driver = driver_config.override_user_agent(driver)
1415

1516
@validate_call
1617
async def run(

src/neo4j_graphrag/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,9 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15+
from importlib.metadata import version, PackageNotFoundError
16+
17+
try:
18+
__version__ = version("neo4j-graphrag")
19+
except PackageNotFoundError:
20+
__version__ = "0.0.0"

src/neo4j_graphrag/experimental/components/kg_writer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
get_version,
3939
is_version_5_23_or_above,
4040
)
41+
from neo4j_graphrag.utils import driver_config
4142

4243
logger = logging.getLogger(__name__)
4344

@@ -117,7 +118,7 @@ def __init__(
117118
neo4j_database: Optional[str] = None,
118119
batch_size: int = 1000,
119120
):
120-
self.driver = driver
121+
self.driver = driver_config.override_user_agent(driver)
121122
self.neo4j_database = neo4j_database
122123
self.batch_size = batch_size
123124
version_tuple, _, _ = get_version(self.driver, self.neo4j_database)

src/neo4j_graphrag/experimental/components/neo4j_reader.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
TextChunks,
2626
)
2727
from neo4j_graphrag.experimental.pipeline import Component
28+
from neo4j_graphrag.utils import driver_config
2829

2930

3031
class Neo4jChunkReader(Component):
@@ -58,7 +59,7 @@ def __init__(
5859
fetch_embeddings: bool = False,
5960
neo4j_database: Optional[str] = None,
6061
):
61-
self.driver = driver
62+
self.driver = driver_config.override_user_agent(driver)
6263
self.fetch_embeddings = fetch_embeddings
6364
self.neo4j_database = neo4j_database
6465

src/neo4j_graphrag/experimental/components/resolver.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from neo4j_graphrag.experimental.components.types import ResolutionStats
2121
from neo4j_graphrag.experimental.pipeline import Component
22+
from neo4j_graphrag.utils import driver_config
2223

2324

2425
class EntityResolver(Component, abc.ABC):
@@ -34,7 +35,7 @@ def __init__(
3435
driver: neo4j.Driver,
3536
filter_query: Optional[str] = None,
3637
) -> None:
37-
self.driver = driver
38+
self.driver = driver_config.override_user_agent(driver)
3839
self.filter_query = filter_query
3940

4041
@abc.abstractmethod

src/neo4j_graphrag/retrievers/base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
has_vector_index_support,
3131
is_version_5_23_or_above,
3232
)
33+
from neo4j_graphrag.utils import driver_config
3334

3435
T = ParamSpec("T")
3536
P = TypeVar("P")
@@ -88,7 +89,7 @@ class Retriever(ABC, metaclass=RetrieverMetaclass):
8889
VERIFY_NEO4J_VERSION = True
8990

9091
def __init__(self, driver: neo4j.Driver, neo4j_database: Optional[str] = None):
91-
self.driver = driver
92+
self.driver = driver_config.override_user_agent(driver)
9293
self.neo4j_database = neo4j_database
9394
if self.VERIFY_NEO4J_VERSION:
9495
version_tuple, is_aura, _ = get_version(self.driver, self.neo4j_database)
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Copyright (c) "Neo4j"
2+
# Neo4j Sweden AB [https://neo4j.com]
3+
# #
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
# #
8+
# https://www.apache.org/licenses/LICENSE-2.0
9+
# #
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
import neo4j
16+
from neo4j_graphrag import __version__
17+
18+
19+
# Override user-agent used by neo4j package so we can measure usage of the package by version
20+
def override_user_agent(driver: neo4j.Driver) -> neo4j.Driver:
21+
driver._pool.pool_config.user_agent = f"neo4j-graphrag-python/v{__version__}"
22+
return driver

0 commit comments

Comments
 (0)