From 81a211bbf7ad554a92142dbf84ce98d06283caa2 Mon Sep 17 00:00:00 2001 From: Jon Besga Date: Mon, 13 Jan 2025 23:22:26 +0000 Subject: [PATCH] Override neo4j user agent when driver is injected --- .../components/writers/custom_writer.py | 3 ++- src/neo4j_graphrag/__init__.py | 6 +++++ .../experimental/components/kg_writer.py | 3 ++- .../experimental/components/neo4j_reader.py | 3 ++- .../experimental/components/resolver.py | 3 ++- src/neo4j_graphrag/retrievers/base.py | 3 ++- src/neo4j_graphrag/utils/driver_config.py | 22 +++++++++++++++++++ 7 files changed, 38 insertions(+), 5 deletions(-) create mode 100644 src/neo4j_graphrag/utils/driver_config.py diff --git a/examples/customize/build_graph/components/writers/custom_writer.py b/examples/customize/build_graph/components/writers/custom_writer.py index 2c64bd160..4b44a620f 100644 --- a/examples/customize/build_graph/components/writers/custom_writer.py +++ b/examples/customize/build_graph/components/writers/custom_writer.py @@ -6,11 +6,12 @@ from neo4j_graphrag.experimental.components.kg_writer import KGWriter, KGWriterModel from neo4j_graphrag.experimental.components.types import LexicalGraphConfig, Neo4jGraph from pydantic import validate_call +from neo4j_graphrag.utils import driver_config class MyWriter(KGWriter): def __init__(self, driver: neo4j.Driver) -> None: - self.driver = driver + self.driver = driver_config.override_user_agent(driver) @validate_call async def run( diff --git a/src/neo4j_graphrag/__init__.py b/src/neo4j_graphrag/__init__.py index c0199c144..a65a71346 100644 --- a/src/neo4j_graphrag/__init__.py +++ b/src/neo4j_graphrag/__init__.py @@ -12,3 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from importlib.metadata import version, PackageNotFoundError + +try: + __version__ = version("neo4j-graphrag") +except PackageNotFoundError: + __version__ = "0.0.0" diff --git a/src/neo4j_graphrag/experimental/components/kg_writer.py b/src/neo4j_graphrag/experimental/components/kg_writer.py index 569115d9f..4a859b36e 100644 --- a/src/neo4j_graphrag/experimental/components/kg_writer.py +++ b/src/neo4j_graphrag/experimental/components/kg_writer.py @@ -38,6 +38,7 @@ get_version, is_version_5_23_or_above, ) +from neo4j_graphrag.utils import driver_config logger = logging.getLogger(__name__) @@ -117,7 +118,7 @@ def __init__( neo4j_database: Optional[str] = None, batch_size: int = 1000, ): - self.driver = driver + self.driver = driver_config.override_user_agent(driver) self.neo4j_database = neo4j_database self.batch_size = batch_size version_tuple, _, _ = get_version(self.driver, self.neo4j_database) diff --git a/src/neo4j_graphrag/experimental/components/neo4j_reader.py b/src/neo4j_graphrag/experimental/components/neo4j_reader.py index 352ed1a6e..24d25fd3d 100644 --- a/src/neo4j_graphrag/experimental/components/neo4j_reader.py +++ b/src/neo4j_graphrag/experimental/components/neo4j_reader.py @@ -25,6 +25,7 @@ TextChunks, ) from neo4j_graphrag.experimental.pipeline import Component +from neo4j_graphrag.utils import driver_config class Neo4jChunkReader(Component): @@ -58,7 +59,7 @@ def __init__( fetch_embeddings: bool = False, neo4j_database: Optional[str] = None, ): - self.driver = driver + self.driver = driver_config.override_user_agent(driver) self.fetch_embeddings = fetch_embeddings self.neo4j_database = neo4j_database diff --git a/src/neo4j_graphrag/experimental/components/resolver.py b/src/neo4j_graphrag/experimental/components/resolver.py index f2da0bff5..a050ea35e 100644 --- a/src/neo4j_graphrag/experimental/components/resolver.py +++ b/src/neo4j_graphrag/experimental/components/resolver.py @@ -19,6 +19,7 @@ from neo4j_graphrag.experimental.components.types import ResolutionStats from neo4j_graphrag.experimental.pipeline import Component +from neo4j_graphrag.utils import driver_config class EntityResolver(Component, abc.ABC): @@ -34,7 +35,7 @@ def __init__( driver: neo4j.Driver, filter_query: Optional[str] = None, ) -> None: - self.driver = driver + self.driver = driver_config.override_user_agent(driver) self.filter_query = filter_query @abc.abstractmethod diff --git a/src/neo4j_graphrag/retrievers/base.py b/src/neo4j_graphrag/retrievers/base.py index 116ce395c..c3b295d15 100644 --- a/src/neo4j_graphrag/retrievers/base.py +++ b/src/neo4j_graphrag/retrievers/base.py @@ -30,6 +30,7 @@ has_vector_index_support, is_version_5_23_or_above, ) +from neo4j_graphrag.utils import driver_config T = ParamSpec("T") P = TypeVar("P") @@ -88,7 +89,7 @@ class Retriever(ABC, metaclass=RetrieverMetaclass): VERIFY_NEO4J_VERSION = True def __init__(self, driver: neo4j.Driver, neo4j_database: Optional[str] = None): - self.driver = driver + self.driver = driver_config.override_user_agent(driver) self.neo4j_database = neo4j_database if self.VERIFY_NEO4J_VERSION: version_tuple, is_aura, _ = get_version(self.driver, self.neo4j_database) diff --git a/src/neo4j_graphrag/utils/driver_config.py b/src/neo4j_graphrag/utils/driver_config.py new file mode 100644 index 000000000..09422d8a7 --- /dev/null +++ b/src/neo4j_graphrag/utils/driver_config.py @@ -0,0 +1,22 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import neo4j +from neo4j_graphrag import __version__ + + +# Override user-agent used by neo4j package so we can measure usage of the package by version +def override_user_agent(driver: neo4j.Driver) -> neo4j.Driver: + driver._pool.pool_config.user_agent = f"neo4j-graphrag-python/v{__version__}" + return driver