Skip to content

Commit 4dae4eb

Browse files
authored
Refactors upsert vector function (#272)
* Added upsert_embeddings function * Added tests for upsert_embeddings * Updated docs and README * Added deprecation warnings to docstrings * Renamed function * Small variable name changes * Formatting fixes * Added deprecation comments * Updated CHANGELOG * Updated upsert_vectors docstring * Fixed README issue * Updated docs
1 parent a414ca6 commit 4dae4eb

File tree

9 files changed

+279
-14
lines changed

9 files changed

+279
-14
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,13 @@
66

77
- Utility functions to retrieve metadata for vector and full-text indexes.
88
- Support for effective_search_ratio parameter in vector and hybrid searches.
9+
- Introduced upsert_vectors utility function for batch upserting embeddings to vector indexes.
910

1011
### Changed
1112

1213
- Refactored index-related functions for improved compatibility and functionality.
14+
- Added deprecation warnings to upsert_vector, upsert_vector_on_relationship functions in favor of upsert_vectors.
15+
- Added deprecation warnings to async_upsert_vector, async_upsert_vector_on_relationship functions notifying developers that they will be removed in a future release.
1316

1417
## 1.4.3
1518

README.md

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,8 @@ Ensure that your vector index is created prior to executing this example.
194194
```python
195195
from neo4j import GraphDatabase
196196
from neo4j_graphrag.embeddings import OpenAIEmbeddings
197-
from neo4j_graphrag.indexes import upsert_vector
197+
from neo4j_graphrag.indexes import upsert_vectors
198+
from neo4j_graphrag.types import EntityType
198199

199200
NEO4J_URI = "neo4j://localhost:7687"
200201
NEO4J_USERNAME = "neo4j"
@@ -214,11 +215,12 @@ text = (
214215
vector = embedder.embed_query(text)
215216

216217
# Upsert the vector
217-
upsert_vector(
218+
upsert_vectors(
218219
driver,
219-
node_id=0,
220-
embedding_property="embedding",
221-
vector=vector,
220+
ids=["1234"],
221+
embedding_property="vectorProperty",
222+
embeddings=[vector],
223+
entity_type=EntityType.NODE,
222224
)
223225
driver.close()
224226
```

docs/source/api.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,8 @@ Database Interaction
382382

383383
.. autofunction:: neo4j_graphrag.indexes.drop_index_if_exists
384384

385+
.. autofunction:: neo4j_graphrag.indexes.upsert_vectors
386+
385387
.. autofunction:: neo4j_graphrag.indexes.upsert_vector
386388

387389
.. autofunction:: neo4j_graphrag.indexes.upsert_vector_on_relationship

docs/source/index.rst

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -148,9 +148,9 @@ Note that the below example is not the only way you can upsert data into your Ne
148148

149149

150150
.. code:: python
151-
152151
from neo4j import GraphDatabase
153-
from neo4j_graphrag.indexes import upsert_vector
152+
from neo4j_graphrag.indexes import upsert_vectors
153+
from neo4j_graphrag.types import EntityType
154154
155155
URI = "neo4j://localhost:7687"
156156
AUTH = ("neo4j", "password")
@@ -160,11 +160,12 @@ Note that the below example is not the only way you can upsert data into your Ne
160160
161161
# Upsert the vector
162162
vector = ...
163-
upsert_vector(
163+
upsert_vectors(
164164
driver,
165-
node_id=1,
165+
ids=["1234"],
166166
embedding_property="vectorProperty",
167-
vector=vector,
167+
embeddings=[vector],
168+
entity_type=EntityType.NODE,
168169
)
169170
170171

docs/source/user_guide_rag.rst

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -917,9 +917,11 @@ Populate a Vector Index
917917
==========================
918918

919919
.. code:: python
920+
from random import random
920921
921922
from neo4j import GraphDatabase
922-
from random import random
923+
from neo4j_graphrag.indexes import upsert_vectors
924+
from neo4j_graphrag.types import EntityType
923925
924926
URI = "neo4j://localhost:7687"
925927
AUTH = ("neo4j", "password")
@@ -928,10 +930,17 @@ Populate a Vector Index
928930
driver = GraphDatabase.driver(URI, auth=AUTH)
929931
930932
# Upsert the vector
933+
DIMENSION = 1536
931934
vector = [random() for _ in range(DIMENSION)]
932-
upsert_vector(driver, node_id="1234", embedding_property="embedding", vector=vector)
935+
upsert_vectors(
936+
driver,
937+
ids=["1234"],
938+
embedding_property="vectorProperty",
939+
embeddings=[vector],
940+
entity_type=EntityType.NODE,
941+
)
933942
934-
This will update the node with `id(node)=1234` to add (or update) a `node.embedding` property.
943+
This will update the node with `id(node)=1234` to add (or update) a `node.vectorProperty` property.
935944
This property will also be added to the vector index.
936945

937946

src/neo4j_graphrag/indexes.py

Lines changed: 116 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from __future__ import annotations
1616

1717
import logging
18+
import warnings
1819
from typing import List, Literal, Optional
1920

2021
import neo4j
@@ -23,10 +24,12 @@
2324
from neo4j_graphrag.neo4j_queries import (
2425
UPSERT_VECTOR_ON_NODE_QUERY,
2526
UPSERT_VECTOR_ON_RELATIONSHIP_QUERY,
27+
UPSERT_VECTORS_ON_NODE_QUERY,
28+
UPSERT_VECTORS_ON_RELATIONSHIP_QUERY,
2629
)
2730

2831
from .exceptions import Neo4jIndexError, Neo4jInsertionError
29-
from .types import FulltextIndexModel, VectorIndexModel
32+
from .types import EntityType, FulltextIndexModel, VectorIndexModel
3033

3134
logger = logging.getLogger(__name__)
3235

@@ -245,6 +248,86 @@ def drop_index_if_exists(
245248
raise Neo4jIndexError(f"Dropping Neo4j index failed: {e.message}") from e
246249

247250

251+
def upsert_vectors(
252+
driver: neo4j.Driver,
253+
ids: List[str],
254+
embedding_property: str,
255+
embeddings: List[List[float]],
256+
neo4j_database: Optional[str] = None,
257+
entity_type: EntityType = EntityType.NODE,
258+
) -> None:
259+
"""
260+
This method constructs a Cypher query and executes it to upsert
261+
(insert or update) embeddings on a set of nodes or relationships.
262+
263+
Example:
264+
265+
.. code-block:: python
266+
267+
from neo4j import GraphDatabase
268+
from neo4j_graphrag.indexes import upsert_vectors
269+
270+
URI = "neo4j://localhost:7687"
271+
AUTH = ("neo4j", "password")
272+
273+
# Connect to Neo4j database
274+
driver = GraphDatabase.driver(URI, auth=AUTH)
275+
276+
# Upsert embeddings data for several nodes
277+
upsert_vectors(
278+
driver,
279+
ids=['123', '456', '789'],
280+
embedding_property="vectorProperty",
281+
embeddings=[
282+
[0.12, 0.34, 0.56],
283+
[0.78, 0.90, 0.12],
284+
[0.34, 0.56, 0.78],
285+
],
286+
neo4j_database="neo4j",
287+
entity_type='NODE',
288+
)
289+
290+
Args:
291+
driver (neo4j.Driver): Neo4j Python driver instance.
292+
ids (List[int]): The element IDs of the nodes or relationships.
293+
embedding_property (str): The name of the property to store the vectors in.
294+
embeddings (List[List[float]]): The list of vectors to store, one per ID.
295+
neo4j_database (Optional[str]): The name of the Neo4j database.
296+
If not provided, defaults to the server's default database. 'neo4j' by default.
297+
entity_type (EntityType): Specifies whether to upsert to nodes ('NODE') or relationships ('RELATIONSHIP').
298+
Defaults to 'NODE'.
299+
300+
Raises:
301+
ValueError: If the lengths of IDs and embeddings do not match, or if embeddings are not of uniform dimension.
302+
Neo4jInsertionError: If an error occurs while attempting to upsert the vectors in Neo4j.
303+
"""
304+
if entity_type == EntityType.NODE:
305+
query = UPSERT_VECTORS_ON_NODE_QUERY
306+
elif entity_type == EntityType.RELATIONSHIP:
307+
query = UPSERT_VECTORS_ON_RELATIONSHIP_QUERY
308+
else:
309+
raise ValueError("entity_type must be either 'NODE' or 'RELATIONSHIP'")
310+
if len(ids) != len(embeddings):
311+
raise ValueError("ids and embeddings must be the same length")
312+
if not all(len(embedding) == len(embeddings[0]) for embedding in embeddings):
313+
raise ValueError("All embeddings must be of the same size")
314+
try:
315+
parameters = {
316+
"rows": [
317+
{"id": id, "embedding": embedding}
318+
for id, embedding in zip(ids, embeddings)
319+
],
320+
"embedding_property": embedding_property,
321+
}
322+
driver.execute_query(
323+
query_=query, parameters_=parameters, database_=neo4j_database
324+
)
325+
except neo4j.exceptions.ClientError as e:
326+
raise Neo4jInsertionError(
327+
f"Upserting vectors to Neo4j failed: {e.message}"
328+
) from e
329+
330+
248331
def upsert_vector(
249332
driver: neo4j.Driver,
250333
node_id: int,
@@ -253,6 +336,9 @@ def upsert_vector(
253336
neo4j_database: Optional[str] = None,
254337
) -> None:
255338
"""
339+
.. warning::
340+
'upsert_vector' is deprecated and will be removed in a future version, please use 'upsert_vectors' instead.
341+
256342
This method constructs a Cypher query and executes it to upsert (insert or update) a vector property on a specific node.
257343
258344
Example:
@@ -286,6 +372,11 @@ def upsert_vector(
286372
Raises:
287373
Neo4jInsertionError: If upserting of the vector fails.
288374
"""
375+
warnings.warn(
376+
"'upsert_vector' is deprecated and will be removed in a future version, please use 'upsert_vectors' instead.",
377+
DeprecationWarning,
378+
stacklevel=2,
379+
)
289380
try:
290381
parameters = {
291382
"node_element_id": node_id,
@@ -309,6 +400,9 @@ def upsert_vector_on_relationship(
309400
neo4j_database: Optional[str] = None,
310401
) -> None:
311402
"""
403+
.. warning::
404+
'upsert_vector_on_relationship' is deprecated and will be removed in a future version, please use 'upsert_vectors' instead.
405+
312406
This method constructs a Cypher query and executes it to upsert (insert or update) a vector property on a specific relationship.
313407
314408
Example:
@@ -342,6 +436,11 @@ def upsert_vector_on_relationship(
342436
Raises:
343437
Neo4jInsertionError: If upserting of the vector fails.
344438
"""
439+
warnings.warn(
440+
"'upsert_vector_on_relationship' is deprecated and will be removed in a future version, please use 'upsert_vectors' instead.",
441+
DeprecationWarning,
442+
stacklevel=2,
443+
)
345444
try:
346445
parameters = {
347446
"rel_element_id": rel_id,
@@ -365,6 +464,9 @@ async def async_upsert_vector(
365464
neo4j_database: Optional[str] = None,
366465
) -> None:
367466
"""
467+
.. warning::
468+
'async_upsert_vector' is deprecated and will be removed in a future version.
469+
368470
This method constructs a Cypher query and asynchronously executes it
369471
to upsert (insert or update) a vector property on a specific node.
370472
@@ -399,6 +501,11 @@ async def async_upsert_vector(
399501
Raises:
400502
Neo4jInsertionError: If upserting of the vector fails.
401503
"""
504+
warnings.warn(
505+
"'async_upsert_vector' is deprecated and will be removed in a future version.",
506+
DeprecationWarning,
507+
stacklevel=2,
508+
)
402509
try:
403510
parameters = {
404511
"node_id": node_id,
@@ -422,6 +529,9 @@ async def async_upsert_vector_on_relationship(
422529
neo4j_database: Optional[str] = None,
423530
) -> None:
424531
"""
532+
.. warning::
533+
'async_upsert_vector_on_relationship' is deprecated and will be removed in a future version.
534+
425535
This method constructs a Cypher query and asynchronously executes it
426536
to upsert (insert or update) a vector property on a specific relationship.
427537
@@ -456,6 +566,11 @@ async def async_upsert_vector_on_relationship(
456566
Raises:
457567
Neo4jInsertionError: If upserting of the vector fails.
458568
"""
569+
warnings.warn(
570+
"'async_upsert_vector_on_relationship' is deprecated and will be removed in a future version.",
571+
DeprecationWarning,
572+
stacklevel=2,
573+
)
459574
try:
460575
parameters = {
461576
"rel_id": rel_id,

src/neo4j_graphrag/neo4j_queries.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@
107107
"RETURN elementId(rel)"
108108
)
109109

110+
# Deprecated, remove along with upsert_vector
110111
UPSERT_VECTOR_ON_NODE_QUERY = (
111112
"MATCH (n) "
112113
"WHERE elementId(n) = $node_element_id "
@@ -115,6 +116,16 @@
115116
"RETURN n"
116117
)
117118

119+
UPSERT_VECTORS_ON_NODE_QUERY = (
120+
"UNWIND $rows AS row "
121+
"MATCH (n) "
122+
"WHERE elementId(n) = row.id "
123+
"WITH n, row "
124+
"CALL db.create.setNodeVectorProperty(n, $embedding_property, row.embedding) "
125+
"RETURN n"
126+
)
127+
128+
# Deprecated, remove along with upsert_vector_on_relationship
118129
UPSERT_VECTOR_ON_RELATIONSHIP_QUERY = (
119130
"MATCH ()-[r]->() "
120131
"WHERE elementId(r) = $rel_element_id "
@@ -123,6 +134,15 @@
123134
"RETURN r"
124135
)
125136

137+
UPSERT_VECTORS_ON_RELATIONSHIP_QUERY = (
138+
"UNWIND $rows AS row "
139+
"MATCH ()-[r]->() "
140+
"WHERE elementId(r) = row.id "
141+
"WITH r, row "
142+
"CALL db.create.setRelationshipVectorProperty(r, $embedding_property, row.embedding) "
143+
"RETURN r"
144+
)
145+
126146

127147
def _get_hybrid_query(neo4j_version_is_5_23_or_above: bool) -> str:
128148
"""

0 commit comments

Comments
 (0)