Skip to content

Override neo4j user agent when driver is injected #243

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@
from neo4j_graphrag.experimental.components.kg_writer import KGWriter, KGWriterModel
from neo4j_graphrag.experimental.components.types import LexicalGraphConfig, Neo4jGraph
from pydantic import validate_call
from neo4j_graphrag.utils import driver_config


class MyWriter(KGWriter):
def __init__(self, driver: neo4j.Driver) -> None:
self.driver = driver
self.driver = driver_config.override_user_agent(driver)

@validate_call
async def run(
Expand Down
6 changes: 6 additions & 0 deletions src/neo4j_graphrag/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from importlib.metadata import version, PackageNotFoundError

try:
__version__ = version("neo4j-graphrag")
except PackageNotFoundError:
__version__ = "0.0.0"
3 changes: 2 additions & 1 deletion src/neo4j_graphrag/experimental/components/kg_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
get_version,
is_version_5_23_or_above,
)
from neo4j_graphrag.utils import driver_config

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -117,7 +118,7 @@ def __init__(
neo4j_database: Optional[str] = None,
batch_size: int = 1000,
):
self.driver = driver
self.driver = driver_config.override_user_agent(driver)
self.neo4j_database = neo4j_database
self.batch_size = batch_size
version_tuple, _, _ = get_version(self.driver, self.neo4j_database)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
TextChunks,
)
from neo4j_graphrag.experimental.pipeline import Component
from neo4j_graphrag.utils import driver_config


class Neo4jChunkReader(Component):
Expand Down Expand Up @@ -58,7 +59,7 @@ def __init__(
fetch_embeddings: bool = False,
neo4j_database: Optional[str] = None,
):
self.driver = driver
self.driver = driver_config.override_user_agent(driver)
self.fetch_embeddings = fetch_embeddings
self.neo4j_database = neo4j_database

Expand Down
3 changes: 2 additions & 1 deletion src/neo4j_graphrag/experimental/components/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from neo4j_graphrag.experimental.components.types import ResolutionStats
from neo4j_graphrag.experimental.pipeline import Component
from neo4j_graphrag.utils import driver_config


class EntityResolver(Component, abc.ABC):
Expand All @@ -34,7 +35,7 @@ def __init__(
driver: neo4j.Driver,
filter_query: Optional[str] = None,
) -> None:
self.driver = driver
self.driver = driver_config.override_user_agent(driver)
self.filter_query = filter_query

@abc.abstractmethod
Expand Down
3 changes: 2 additions & 1 deletion src/neo4j_graphrag/retrievers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
has_vector_index_support,
is_version_5_23_or_above,
)
from neo4j_graphrag.utils import driver_config

T = ParamSpec("T")
P = TypeVar("P")
Expand Down Expand Up @@ -88,7 +89,7 @@ class Retriever(ABC, metaclass=RetrieverMetaclass):
VERIFY_NEO4J_VERSION = True

def __init__(self, driver: neo4j.Driver, neo4j_database: Optional[str] = None):
self.driver = driver
self.driver = driver_config.override_user_agent(driver)
self.neo4j_database = neo4j_database
if self.VERIFY_NEO4J_VERSION:
version_tuple, is_aura, _ = get_version(self.driver, self.neo4j_database)
Expand Down
22 changes: 22 additions & 0 deletions src/neo4j_graphrag/utils/driver_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [https://neo4j.com]
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# https://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import neo4j
from neo4j_graphrag import __version__


# Override user-agent used by neo4j package so we can measure usage of the package by version
def override_user_agent(driver: neo4j.Driver) -> neo4j.Driver:
driver._pool.pool_config.user_agent = f"neo4j-graphrag-python/v{__version__}"
return driver