Skip to content

Commit 607e667

Browse files
authored
Stopped the embedding from being returned with VectorRetriever (neo4j#58)
* Stopped the embedding from being returned with VectorRetriever
* Updated test_neo4j_queries.py
* Removed apoc calls
* Added nodeLabels back in
* Added nodeLabels and id to VectorRetriever metadata
* Updated CHANGELOG
* Updated CHANGELOG
1 parent 0769236 commit 607e667

File tree

6 files changed

+26
-10
lines changed

6 files changed

+26
-10
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22

33
## Next
44

5+
### Added
6+
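- Stopped embeddings from being returned when searching with `VectorRetriever`: the embedding property of each returned node is now set to `null` instead of containing the full vector. Added `nodeLabels` and `id` to the metadata of `VectorRetriever` results.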
7+
58
## 0.2.0
69

710
### Fixed

src/neo4j_genai/neo4j_queries.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,9 @@ def get_search_query(
147147
else:
148148
raise ValueError(f"Search type is not supported: {search_type}")
149149
query_tail = get_query_tail(
150-
retrieval_query, return_properties, fallback_return="RETURN node, score"
150+
retrieval_query,
151+
return_properties,
152+
fallback_return=f"RETURN node {{ .*, `{embedding_node_property}`: null }} AS node, labels(node) AS nodeLabels, elementId(node) AS id, score",
151153
)
152154
return f"{query} {query_tail}", params
153155

src/neo4j_genai/retrievers/vector.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,8 @@ def default_format_record(self, record: neo4j.Record) -> RetrieverResultItem:
108108
"""
109109
metadata = {
110110
"score": record.get("score"),
111+
"nodeLabels": record.get("nodeLabels"),
112+
"id": record.get("id"),
111113
}
112114
node = record.get("node")
113115
return RetrieverResultItem(

tests/e2e/test_vector_e2e.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def test_vector_retriever_search_text(
3232
assert isinstance(results, RetrieverResult)
3333
assert len(results.items) == 5
3434
for result in results.items:
35+
assert f"'{retriever._embedding_node_property}': None" in result.content
3536
assert isinstance(result, RetrieverResultItem)
3637

3738

@@ -66,6 +67,7 @@ def test_vector_retriever_search_vector(driver: Driver) -> None:
6667
assert isinstance(results, RetrieverResult)
6768
assert len(results.items) == 5
6869
for result in results.items:
70+
assert f"'{retriever._embedding_node_property}': None" in result.content
6971
assert isinstance(result, RetrieverResultItem)
7072

7173

tests/unit/retrievers/test_vector.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,8 @@ def test_similarity_search_vector_happy_path(
101101
assert records == RetrieverResult(
102102
items=[
103103
RetrieverResultItem(
104-
content="{'text': 'dummy-node'}", metadata={"score": 1.0}
104+
content="{'text': 'dummy-node'}",
105+
metadata={"score": 1.0, "nodeLabels": None, "id": None},
105106
),
106107
],
107108
metadata={"__retriever": "VectorRetriever"},
@@ -143,7 +144,10 @@ def test_similarity_search_text_happy_path(
143144
)
144145
assert records == RetrieverResult(
145146
items=[
146-
RetrieverResultItem(content="dummy-node", metadata={"score": 1.0}),
147+
RetrieverResultItem(
148+
content="dummy-node",
149+
metadata={"score": 1.0, "nodeLabels": None, "id": None},
150+
),
147151
],
148152
metadata={"__retriever": "VectorRetriever"},
149153
)
@@ -191,7 +195,10 @@ def test_similarity_search_text_return_properties(
191195
)
192196
assert records == RetrieverResult(
193197
items=[
194-
RetrieverResultItem(content="dummy-node", metadata={"score": 1.0}),
198+
RetrieverResultItem(
199+
content="dummy-node",
200+
metadata={"score": 1.0, "nodeLabels": None, "id": None},
201+
),
195202
],
196203
metadata={"__retriever": "VectorRetriever"},
197204
)

tests/unit/test_neo4j_queries.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def test_vector_search_basic() -> None:
2222
expected = (
2323
"CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) "
2424
"YIELD node, score "
25-
"RETURN node, score"
25+
"RETURN node { .*, `None`: null } AS node, labels(node) AS nodeLabels, elementId(node) AS id, score"
2626
)
2727
result, params = get_search_query(SearchType.VECTOR)
2828
assert result.strip() == expected.strip()
@@ -42,7 +42,7 @@ def test_hybrid_search_basic() -> None:
4242
"RETURN n.node AS node, (n.score / max) AS score "
4343
"} "
4444
"WITH node, max(score) AS score ORDER BY score DESC LIMIT $top_k "
45-
"RETURN node, score"
45+
"RETURN node { .*, `None`: null } AS node, labels(node) AS nodeLabels, elementId(node) AS id, score"
4646
)
4747
result, _ = get_search_query(SearchType.HYBRID)
4848
assert result.strip() == expected.strip()
@@ -78,8 +78,8 @@ def test_vector_search_with_filters(_mock: Any) -> None:
7878
" AND (True) "
7979
"WITH node, "
8080
"vector.similarity.cosine(node.`vector`, $query_vector) AS score "
81-
"ORDER BY score DESC LIMIT $top_k"
82-
" RETURN node, score"
81+
"ORDER BY score DESC LIMIT $top_k "
82+
"RETURN node { .*, `vector`: null } AS node, labels(node) AS nodeLabels, elementId(node) AS id, score"
8383
)
8484
result, params = get_search_query(
8585
SearchType.VECTOR,
@@ -104,8 +104,8 @@ def test_vector_search_with_params_from_filters(_mock: Any) -> None:
104104
" AND (True) "
105105
"WITH node, "
106106
"vector.similarity.cosine(node.`vector`, $query_vector) AS score "
107-
"ORDER BY score DESC LIMIT $top_k"
108-
" RETURN node, score"
107+
"ORDER BY score DESC LIMIT $top_k "
108+
"RETURN node { .*, `vector`: null } AS node, labels(node) AS nodeLabels, elementId(node) AS id, score"
109109
)
110110
result, params = get_search_query(
111111
SearchType.VECTOR,

0 commit comments

Comments
 (0)