Skip to content

Commit f454e50

Browse files
willtaijonbesga
andauthored
Allow specifying database name for retrievers and index handling methods (neo4j#71)
* Include database name in driver.execute_query * Update tests * Added configurable database names to the Pinecone and Weaviate retrievers * Fix typing for filtered vector queries (neo4j#76) * Rename database to neo4j_database --------- Co-authored-by: Jon Besga <jon.besga@neo4j.com>
1 parent 09975ce commit f454e50

File tree

15 files changed

+238
-24
lines changed

15 files changed

+238
-24
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,14 @@
1212
- Added LLMInterface with implementation for OpenAI LLM.
1313
- Updated project configuration to support multiple Python versions (3.8 to 3.12) in CI workflows.
1414
- Improved developer experience by copying the docstring from the `Retriever.get_search_results` method to the `Retriever.search` method
15+
- Support for specifying database names in index handling methods and retrievers.
1516

1617
### Changed
1718
- Refactored import paths for retrievers to neo4j_genai.retrievers.
1819
- Implemented exception chaining for all re-raised exceptions to improve stack trace readability.
1920
- Made error messages in `index.py` more consistent.
2021
- Renamed `Retriever._get_search_results` to `Retriever.get_search_results`
22+
- Updated retrievers and index handling methods to accept optional database names.
2123

2224
## 0.2.0
2325

src/neo4j_genai/indexes.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from __future__ import annotations
1616

1717
import logging
18-
from typing import Literal
18+
from typing import Literal, Optional
1919

2020
import neo4j
2121
from pydantic import ValidationError
@@ -33,6 +33,7 @@ def create_vector_index(
3333
embedding_property: str,
3434
dimensions: int,
3535
similarity_fn: Literal["euclidean", "cosine"],
36+
neo4j_database: Optional[str] = None,
3637
) -> None:
3738
"""
3839
This method constructs a Cypher query and executes it
@@ -77,6 +78,7 @@ def create_vector_index(
7778
dimensions (int): Vector embedding dimension
7879
similarity_fn (str): case-insensitive values for the vector similarity function:
7980
``euclidean`` or ``cosine``.
81+
neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to "neo4j" in the database (`see reference to documentation <https://neo4j.com/docs/operations-manual/current/database-administration/#manage-databases-default>`_).
8082
8183
Raises:
8284
ValueError: If validation of the input arguments fail.
@@ -105,13 +107,18 @@ def create_vector_index(
105107
driver.execute_query(
106108
query,
107109
{"name": name, "dimensions": dimensions, "similarity_fn": similarity_fn},
110+
database_=neo4j_database,
108111
)
109112
except neo4j.exceptions.ClientError as e:
110113
raise Neo4jIndexError(f"Neo4j vector index creation failed: {e.message}") from e
111114

112115

113116
def create_fulltext_index(
114-
driver: neo4j.Driver, name: str, label: str, node_properties: list[str]
117+
driver: neo4j.Driver,
118+
name: str,
119+
label: str,
120+
node_properties: list[str],
121+
neo4j_database: Optional[str] = None,
115122
) -> None:
116123
"""
117124
This method constructs a Cypher query and executes it
@@ -151,6 +158,7 @@ def create_fulltext_index(
151158
name (str): The unique name of the index.
152159
label (str): The node label to be indexed.
153160
node_properties (list[str]): The node properties to create the fulltext index on.
161+
neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to "neo4j" in the database (`see reference to documentation <https://neo4j.com/docs/operations-manual/current/database-administration/#manage-databases-default>`_).
154162
155163
Raises:
156164
ValueError: If validation of the input arguments fail.
@@ -172,14 +180,16 @@ def create_fulltext_index(
172180
f"[{', '.join(['n.`' + prop + '`' for prop in node_properties])}]"
173181
)
174182
logger.info(f"Creating fulltext index named '{name}'")
175-
driver.execute_query(query, {"name": name})
183+
driver.execute_query(query, {"name": name}, database_=neo4j_database)
176184
except neo4j.exceptions.ClientError as e:
177185
raise Neo4jIndexError(
178186
f"Neo4j fulltext index creation failed {e.message}"
179187
) from e
180188

181189

182-
def drop_index_if_exists(driver: neo4j.Driver, name: str) -> None:
190+
def drop_index_if_exists(
191+
driver: neo4j.Driver, name: str, neo4j_database: Optional[str] = None
192+
) -> None:
183193
"""
184194
This method constructs a Cypher query and executes it
185195
to drop an index in Neo4j, if the index exists.
@@ -210,6 +220,7 @@ def drop_index_if_exists(driver: neo4j.Driver, name: str) -> None:
210220
Args:
211221
driver (neo4j.Driver): Neo4j Python driver instance.
212222
name (str): The name of the index to delete.
223+
neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to "neo4j" in the database (`see reference to documentation <https://neo4j.com/docs/operations-manual/current/database-administration/#manage-databases-default>`_).
213224
214225
Raises:
215226
neo4j.exceptions.ClientError: If dropping of index fails.
@@ -220,7 +231,7 @@ def drop_index_if_exists(driver: neo4j.Driver, name: str) -> None:
220231
"name": name,
221232
}
222233
logger.info(f"Dropping index named '{name}'")
223-
driver.execute_query(query, parameters)
234+
driver.execute_query(query, parameters, database_=neo4j_database)
224235
except neo4j.exceptions.ClientError as e:
225236
raise Neo4jIndexError(f"Dropping Neo4j index failed: {e.message}") from e
226237

@@ -230,6 +241,7 @@ def upsert_vector(
230241
node_id: int,
231242
embedding_property: str,
232243
vector: list[float],
244+
neo4j_database: Optional[str] = None,
233245
) -> None:
234246
"""
235247
This method constructs a Cypher query and executes it to upsert (insert or update) a vector property on a specific node.
@@ -260,6 +272,7 @@ def upsert_vector(
260272
node_id (int): The id of the node.
261273
embedding_property (str): The name of the property to store the vector in.
262274
vector (list[float]): The vector to store.
275+
neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to "neo4j" in the database (`see reference to documentation <https://neo4j.com/docs/operations-manual/current/database-administration/#manage-databases-default>`_).
263276
264277
Raises:
265278
Neo4jInsertionError: If upserting of the vector fails.
@@ -277,7 +290,7 @@ def upsert_vector(
277290
"embedding_property": embedding_property,
278291
"vector": vector,
279292
}
280-
driver.execute_query(query, parameters)
293+
driver.execute_query(query, parameters, database_=neo4j_database)
281294
except neo4j.exceptions.ClientError as e:
282295
raise Neo4jInsertionError(
283296
f"Upserting vector to Neo4j failed: {e.message}"

src/neo4j_genai/retrievers/base.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,9 @@ class Retriever(ABC, metaclass=RetrieverMetaclass):
8181
index_name: str
8282
VERIFY_NEO4J_VERSION = True
8383

84-
def __init__(self, driver: neo4j.Driver):
84+
def __init__(self, driver: neo4j.Driver, neo4j_database: Optional[str] = None):
8585
self.driver = driver
86+
self.neo4j_database = neo4j_database
8687
if self.VERIFY_NEO4J_VERSION:
8788
self._verify_version()
8889

@@ -95,7 +96,9 @@ def _verify_version(self) -> None:
9596
indexing. Raises a Neo4jMinVersionError if the connected Neo4j version is
9697
not supported.
9798
"""
98-
records, _, _ = self.driver.execute_query("CALL dbms.components()")
99+
records, _, _ = self.driver.execute_query(
100+
"CALL dbms.components()", database_=self.neo4j_database
101+
)
99102
version = records[0]["versions"][0]
100103

101104
if "aura" in version:
@@ -120,7 +123,9 @@ def _fetch_index_infos(self) -> None:
120123
"RETURN labelsOrTypes as labels, properties, "
121124
"options.indexConfig.`vector.dimensions` as dimensions"
122125
)
123-
query_result = self.driver.execute_query(query, {"index_name": self.index_name})
126+
query_result = self.driver.execute_query(
127+
query, {"index_name": self.index_name}, database_=self.neo4j_database
128+
)
124129
try:
125130
result = query_result.records[0]
126131
self._node_label = result["labels"][0]
@@ -185,11 +190,16 @@ class ExternalRetriever(Retriever, ABC):
185190
VERIFY_NEO4J_VERSION = False
186191

187192
def __init__(
188-
self, driver: neo4j.Driver, id_property_external: str, id_property_neo4j: str
193+
self,
194+
driver: neo4j.Driver,
195+
id_property_external: str,
196+
id_property_neo4j: str,
197+
neo4j_database: Optional[str] = None,
189198
):
190199
super().__init__(driver)
191200
self.id_property_external = id_property_external
192201
self.id_property_neo4j = id_property_neo4j
202+
self.neo4j_database = neo4j_database
193203

194204
@abstractmethod
195205
def get_search_results(

src/neo4j_genai/retrievers/external/pinecone/pinecone.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ class PineconeNeo4jRetriever(ExternalRetriever):
7979
embedder (Optional[Embedder]): Embedder object to embed query text.
8080
return_properties (Optional[list[str]]): List of node properties to return.
8181
result_formatter (Optional[Callable[[Any], Any]]): Function to transform a neo4j.Record to a RetrieverResultItem.
82+
neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to "neo4j" in the database (`see reference to documentation <https://neo4j.com/docs/operations-manual/current/database-administration/#manage-databases-default>`_).
8283
8384
Raises:
8485
RetrieverInitializationError: If validation of the input arguments fail.
@@ -94,6 +95,7 @@ def __init__(
9495
return_properties: Optional[list[str]] = None,
9596
retrieval_query: Optional[str] = None,
9697
result_formatter: Optional[Callable[[Any], Any]] = None,
98+
neo4j_database: Optional[str] = None,
9799
):
98100
try:
99101
driver_model = Neo4jDriverModel(driver=driver)
@@ -108,6 +110,7 @@ def __init__(
108110
return_properties=return_properties,
109111
retrieval_query=retrieval_query,
110112
result_formatter=result_formatter,
113+
neo4j_database=neo4j_database,
111114
)
112115
except ValidationError as e:
113116
raise RetrieverInitializationError(e.errors()) from e
@@ -116,6 +119,7 @@ def __init__(
116119
driver=driver,
117120
id_property_external="id",
118121
id_property_neo4j=validated_data.id_property_neo4j,
122+
neo4j_database=neo4j_database,
119123
)
120124
self.driver = validated_data.driver_model.driver
121125
self.client = validated_data.client_model.client
@@ -224,6 +228,8 @@ def get_search_results(
224228
logger.debug("Pinecone Store Cypher parameters: %s", parameters)
225229
logger.debug("Pinecone Store Cypher query: %s", search_query)
226230

227-
records, _, _ = self.driver.execute_query(search_query, parameters)
231+
records, _, _ = self.driver.execute_query(
232+
search_query, parameters, database_=self.neo4j_database
233+
)
228234

229235
return RawSearchResult(records=records)

src/neo4j_genai/retrievers/external/pinecone/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,3 +53,4 @@ class PineconeNeo4jRetrieverModel(BaseModel):
5353
return_properties: Optional[list[str]] = None
5454
retrieval_query: Optional[str] = None
5555
result_formatter: Optional[Callable[[neo4j.Record], str]] = None
56+
neo4j_database: Optional[str] = None

src/neo4j_genai/retrievers/external/weaviate/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ class WeaviateNeo4jRetrieverModel(BaseModel):
5151
return_properties: Optional[list[str]] = None
5252
retrieval_query: Optional[str] = None
5353
result_formatter: Optional[Callable[[neo4j.Record], str]] = None
54+
neo4j_database: Optional[str] = None
5455

5556

5657
class WeaviateNeo4jSearchModel(VectorSearchModel):

src/neo4j_genai/retrievers/external/weaviate/weaviate.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ class WeaviateNeo4jRetriever(ExternalRetriever):
7070
embedder (Optional[Embedder]): Embedder object to embed query text.
7171
return_properties (Optional[list[str]]): List of node properties to return.
7272
result_formatter (Optional[Callable[[Any], Any]]): Function to transform a neo4j.Record to a RetrieverResultItem.
73+
neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to "neo4j" in the database (`see reference to documentation <https://neo4j.com/docs/operations-manual/current/database-administration/#manage-databases-default>`_).
7374
7475
Raises:
7576
RetrieverInitializationError: If validation of the input arguments fail.
@@ -86,6 +87,7 @@ def __init__(
8687
return_properties: Optional[list[str]] = None,
8788
retrieval_query: Optional[str] = None,
8889
result_formatter: Optional[Callable[[Any], Any]] = None,
90+
neo4j_database: Optional[str] = None,
8991
):
9092
try:
9193
driver_model = Neo4jDriverModel(driver=driver)
@@ -101,11 +103,14 @@ def __init__(
101103
return_properties=return_properties,
102104
retrieval_query=retrieval_query,
103105
result_formatter=result_formatter,
106+
neo4j_database=neo4j_database,
104107
)
105108
except ValidationError as e:
106109
raise RetrieverInitializationError(e.errors()) from e
107110

108-
super().__init__(driver, id_property_external, id_property_neo4j)
111+
super().__init__(
112+
driver, id_property_external, id_property_neo4j, neo4j_database
113+
)
109114
self.client = validated_data.client_model.client
110115
collection = validated_data.collection
111116
self.search_collection = self.client.collections.get(collection)
@@ -227,6 +232,8 @@ def get_search_results(
227232
logger.debug("Weaviate Store Cypher parameters: %s", parameters)
228233
logger.debug("Weaviate Store Cypher query: %s", search_query)
229234

230-
records, _, _ = self.driver.execute_query(search_query, parameters)
235+
records, _, _ = self.driver.execute_query(
236+
search_query, parameters, database_=self.neo4j_database
237+
)
231238

232239
return RawSearchResult(records=records)

src/neo4j_genai/retrievers/hybrid.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ class HybridRetriever(Retriever):
6969
fulltext_index_name (str): Fulltext index name.
7070
embedder (Optional[Embedder]): Embedder object to embed query text.
7171
return_properties (Optional[list[str]]): List of node properties to return.
72+
neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to "neo4j" in the database (`see reference to documentation <https://neo4j.com/docs/operations-manual/current/database-administration/#manage-databases-default>`_).
73+
7274
"""
7375

7476
def __init__(
@@ -78,6 +80,7 @@ def __init__(
7880
fulltext_index_name: str,
7981
embedder: Optional[Embedder] = None,
8082
return_properties: Optional[list[str]] = None,
83+
neo4j_database: Optional[str] = None,
8184
) -> None:
8285
try:
8386
driver_model = Neo4jDriverModel(driver=driver)
@@ -88,11 +91,14 @@ def __init__(
8891
fulltext_index_name=fulltext_index_name,
8992
embedder_model=embedder_model,
9093
return_properties=return_properties,
94+
neo4j_database=neo4j_database,
9195
)
9296
except ValidationError as e:
9397
raise RetrieverInitializationError(e.errors()) from e
9498

95-
super().__init__(validated_data.driver_model.driver)
99+
super().__init__(
100+
validated_data.driver_model.driver, validated_data.neo4j_database
101+
)
96102
self.vector_index_name = validated_data.vector_index_name
97103
self.fulltext_index_name = validated_data.fulltext_index_name
98104
self.return_properties = validated_data.return_properties
@@ -173,7 +179,9 @@ def get_search_results(
173179
logger.debug("HybridRetriever Cypher parameters: %s", parameters)
174180
logger.debug("HybridRetriever Cypher query: %s", search_query)
175181

176-
records, _, _ = self.driver.execute_query(search_query, parameters)
182+
records, _, _ = self.driver.execute_query(
183+
search_query, parameters, database_=self.neo4j_database
184+
)
177185
return RawSearchResult(
178186
records=records,
179187
)
@@ -212,6 +220,7 @@ class HybridCypherRetriever(Retriever):
212220
retrieval_query (str): Cypher query that gets appended.
213221
embedder (Optional[Embedder]): Embedder object to embed query text.
214222
result_formatter (Optional[Callable[[Any], Any]]): Provided custom function to transform a neo4j.Record to a RetrieverResultItem.
223+
neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to "neo4j" in the database (`see reference to documentation <https://neo4j.com/docs/operations-manual/current/database-administration/#manage-databases-default>`_).
215224
216225
Raises:
217226
RetrieverInitializationError: If validation of the input arguments fail.
@@ -225,6 +234,7 @@ def __init__(
225234
retrieval_query: str,
226235
embedder: Optional[Embedder] = None,
227236
result_formatter: Optional[Callable[[Any], Any]] = None,
237+
neo4j_database: Optional[str] = None,
228238
) -> None:
229239
try:
230240
driver_model = Neo4jDriverModel(driver=driver)
@@ -235,11 +245,14 @@ def __init__(
235245
fulltext_index_name=fulltext_index_name,
236246
retrieval_query=retrieval_query,
237247
embedder_model=embedder_model,
248+
neo4j_database=neo4j_database,
238249
)
239250
except ValidationError as e:
240251
raise RetrieverInitializationError(e.errors()) from e
241252

242-
super().__init__(validated_data.driver_model.driver)
253+
super().__init__(
254+
validated_data.driver_model.driver, validated_data.neo4j_database
255+
)
243256
self.vector_index_name = validated_data.vector_index_name
244257
self.fulltext_index_name = validated_data.fulltext_index_name
245258
self.retrieval_query = validated_data.retrieval_query
@@ -316,7 +329,9 @@ def get_search_results(
316329
logger.debug("HybridCypherRetriever Cypher parameters: %s", parameters)
317330
logger.debug("HybridCypherRetriever Cypher query: %s", search_query)
318331

319-
records, _, _ = self.driver.execute_query(search_query, parameters)
332+
records, _, _ = self.driver.execute_query(
333+
search_query, parameters, database_=self.neo4j_database
334+
)
320335
return RawSearchResult(
321336
records=records,
322337
)

0 commit comments

Comments
 (0)