Skip to content

Commit a99a343

Browse files
authored
Pydantic search model refactoring (neo4j#69)
* Refactored Pydantic search models
* Added VectorCypherSearchModel and HybridCypherSearchModel back in
* Minor Pinecone and Weaviate Pydantic model refactoring
* Removed unneeded imports
* Removed filters from parameters in vector search
* Changed List, Dict type hints to list, dict
1 parent 6fbd38a commit a99a343

File tree

6 files changed

+54
-61
lines changed

6 files changed

+54
-61
lines changed

src/neo4j_genai/retrievers/external/pinecone/pinecone.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -179,15 +179,14 @@ def get_search_results(
179179
RawSearchResult: The results of the search query as a list of neo4j.Record and an optional metadata dict
180180
"""
181181

182-
pinecone_filters = kwargs.get("pinecone_filters")
182+
pinecone_filter = kwargs.get("pinecone_filter")
183183

184184
try:
185185
validated_data = PineconeSearchModel(
186-
vector_index_name=self.index_name,
187186
query_vector=query_vector,
188187
query_text=query_text,
189188
top_k=top_k,
190-
pinecone_filter=pinecone_filters,
189+
pinecone_filter=pinecone_filter,
191190
)
192191
except ValidationError as e:
193192
raise SearchValidationError(e.errors()) from e

src/neo4j_genai/retrievers/external/pinecone/types.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,8 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
from __future__ import annotations
16-
from typing import Any, Callable, Optional
16+
17+
from typing import Any, Callable, Optional, Union
1718

1819
import neo4j
1920
from pinecone import Pinecone
@@ -27,7 +28,9 @@
2728

2829

2930
class PineconeSearchModel(VectorSearchModel):
30-
pinecone_filter: Optional[dict[str, Any]] = None
31+
pinecone_filter: Optional[
32+
dict[str, Union[str, float, int, bool, list[Any], dict[Any, Any]]]
33+
] = None
3134

3235

3336
class PineconeClientModel(BaseModel):

src/neo4j_genai/retrievers/external/weaviate/types.py

Lines changed: 5 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -13,21 +13,19 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
from __future__ import annotations
16-
from typing import Optional, Any, Callable
16+
17+
from typing import Callable, Optional
1718

1819
import neo4j
1920
from pydantic import (
20-
field_validator,
2121
BaseModel,
22-
PositiveInt,
23-
model_validator,
2422
ConfigDict,
23+
field_validator,
2524
)
2625
from weaviate.client import WeaviateClient
2726
from weaviate.collections.classes.filters import _Filters
2827

29-
from neo4j_genai.utils import validate_search_query_input
30-
from neo4j_genai.types import Neo4jDriverModel, EmbedderModel
28+
from neo4j_genai.types import EmbedderModel, Neo4jDriverModel, VectorSearchModel
3129

3230

3331
class WeaviateModel(BaseModel):
@@ -55,10 +53,7 @@ class WeaviateNeo4jRetrieverModel(BaseModel):
5553
result_formatter: Optional[Callable[[neo4j.Record], str]] = None
5654

5755

58-
class WeaviateNeo4jSearchModel(BaseModel):
59-
top_k: PositiveInt = 5
60-
query_vector: Optional[list[float]] = None
61-
query_text: Optional[str] = None
56+
class WeaviateNeo4jSearchModel(VectorSearchModel):
6257
weaviate_filters: Optional[_Filters] = None
6358
model_config = ConfigDict(arbitrary_types_allowed=True)
6459

@@ -69,12 +64,3 @@ def check_weaviate_filters(cls, value: _Filters) -> _Filters:
6964
"Provided filters need to be of type weaviate.collections.classes.filters._Filters"
7065
)
7166
return value
72-
73-
@model_validator(mode="before")
74-
def check_query(cls, values: dict[str, Any]) -> dict[str, Any]:
75-
"""
76-
Validates that one of either query_vector or query_text is provided exclusively.
77-
"""
78-
query_vector, query_text = values.get("query_vector"), values.get("query_text")
79-
validate_search_query_input(query_text, query_vector)
80-
return values

src/neo4j_genai/retrievers/hybrid.py

Lines changed: 16 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -13,31 +13,32 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
from __future__ import annotations
16-
from typing import Optional, Any, Callable
16+
17+
import logging
18+
from typing import Any, Callable, Optional
1719

1820
import neo4j
1921
from pydantic import ValidationError
2022

2123
from neo4j_genai.embedder import Embedder
2224
from neo4j_genai.exceptions import (
25+
EmbeddingRequiredError,
2326
RetrieverInitializationError,
2427
SearchValidationError,
25-
EmbeddingRequiredError,
2628
)
29+
from neo4j_genai.neo4j_queries import get_search_query
2730
from neo4j_genai.retrievers.base import Retriever
2831
from neo4j_genai.types import (
29-
HybridSearchModel,
30-
SearchType,
31-
HybridCypherSearchModel,
32-
Neo4jDriverModel,
3332
EmbedderModel,
34-
HybridRetrieverModel,
3533
HybridCypherRetrieverModel,
34+
HybridCypherSearchModel,
35+
HybridRetrieverModel,
36+
HybridSearchModel,
37+
Neo4jDriverModel,
3638
RawSearchResult,
3739
RetrieverResultItem,
40+
SearchType,
3841
)
39-
from neo4j_genai.neo4j_queries import get_search_query
40-
import logging
4142

4243
logger = logging.getLogger(__name__)
4344

@@ -146,16 +147,16 @@ def get_search_results(
146147
"""
147148
try:
148149
validated_data = HybridSearchModel(
149-
vector_index_name=self.vector_index_name,
150-
fulltext_index_name=self.fulltext_index_name,
151-
top_k=top_k,
152150
query_vector=query_vector,
153151
query_text=query_text,
152+
top_k=top_k,
154153
)
155154
except ValidationError as e:
156155
raise SearchValidationError(e.errors()) from e
157156

158157
parameters = validated_data.model_dump(exclude_none=True)
158+
parameters["vector_index_name"] = self.vector_index_name
159+
parameters["fulltext_index_name"] = self.fulltext_index_name
159160

160161
if query_text and not query_vector:
161162
if not self.embedder:
@@ -276,17 +277,17 @@ def get_search_results(
276277
"""
277278
try:
278279
validated_data = HybridCypherSearchModel(
279-
vector_index_name=self.vector_index_name,
280-
fulltext_index_name=self.fulltext_index_name,
281-
top_k=top_k,
282280
query_vector=query_vector,
283281
query_text=query_text,
282+
top_k=top_k,
284283
query_params=query_params,
285284
)
286285
except ValidationError as e:
287286
raise SearchValidationError(e.errors()) from e
288287

289288
parameters = validated_data.model_dump(exclude_none=True)
289+
parameters["vector_index_name"] = self.vector_index_name
290+
parameters["fulltext_index_name"] = self.fulltext_index_name
290291

291292
if query_text and not query_vector:
292293
if not self.embedder:

src/neo4j_genai/retrievers/vector.py

Lines changed: 23 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -13,32 +13,32 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
from __future__ import annotations
16-
from typing import Optional, Any, Callable
16+
17+
import logging
18+
from typing import Any, Callable, Optional
1719

1820
import neo4j
21+
from pydantic import ValidationError
1922

23+
from neo4j_genai.embedder import Embedder
2024
from neo4j_genai.exceptions import (
25+
EmbeddingRequiredError,
2126
RetrieverInitializationError,
2227
SearchValidationError,
23-
EmbeddingRequiredError,
2428
)
29+
from neo4j_genai.neo4j_queries import get_search_query
2530
from neo4j_genai.retrievers.base import Retriever
26-
from pydantic import ValidationError
27-
28-
from neo4j_genai.embedder import Embedder
2931
from neo4j_genai.types import (
30-
VectorSearchModel,
31-
VectorCypherSearchModel,
32-
SearchType,
33-
Neo4jDriverModel,
3432
EmbedderModel,
35-
VectorRetrieverModel,
36-
VectorCypherRetrieverModel,
33+
Neo4jDriverModel,
3734
RawSearchResult,
3835
RetrieverResultItem,
36+
SearchType,
37+
VectorCypherRetrieverModel,
38+
VectorCypherSearchModel,
39+
VectorRetrieverModel,
40+
VectorSearchModel,
3941
)
40-
from neo4j_genai.neo4j_queries import get_search_query
41-
import logging
4242

4343
logger = logging.getLogger(__name__)
4444

@@ -146,15 +146,18 @@ def get_search_results(
146146
"""
147147
try:
148148
validated_data = VectorSearchModel(
149-
vector_index_name=self.index_name,
150-
top_k=top_k,
151149
query_vector=query_vector,
152150
query_text=query_text,
151+
top_k=top_k,
152+
filters=filters,
153153
)
154154
except ValidationError as e:
155155
raise SearchValidationError(e.errors()) from e
156156

157157
parameters = validated_data.model_dump(exclude_none=True)
158+
parameters["vector_index_name"] = self.index_name
159+
if filters:
160+
del parameters["filters"]
158161

159162
if query_text:
160163
if not self.embedder:
@@ -275,16 +278,19 @@ def get_search_results(
275278
"""
276279
try:
277280
validated_data = VectorCypherSearchModel(
278-
vector_index_name=self.index_name,
279-
top_k=top_k,
280281
query_vector=query_vector,
281282
query_text=query_text,
283+
top_k=top_k,
282284
query_params=query_params,
285+
filters=filters,
283286
)
284287
except ValidationError as e:
285288
raise SearchValidationError(e.errors()) from e
286289

287290
parameters = validated_data.model_dump(exclude_none=True)
291+
parameters["vector_index_name"] = self.index_name
292+
if filters:
293+
del parameters["filters"]
288294

289295
if query_text:
290296
if not self.embedder:

src/neo4j_genai/types.py

Lines changed: 3 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -116,10 +116,10 @@ def check_node_properties_not_empty(cls, v: list[Any]) -> list[Any]:
116116

117117

118118
class VectorSearchModel(BaseModel):
119-
vector_index_name: str
120-
top_k: PositiveInt = 5
121119
query_vector: Optional[list[float]] = None
122120
query_text: Optional[str] = None
121+
top_k: PositiveInt = 5
122+
filters: Optional[dict[str, Any]] = None
123123

124124
@model_validator(mode="before")
125125
def check_query(cls, values: dict[str, Any]) -> dict[str, Any]:
@@ -136,11 +136,9 @@ class VectorCypherSearchModel(VectorSearchModel):
136136

137137

138138
class HybridSearchModel(BaseModel):
139-
vector_index_name: str
140-
fulltext_index_name: str
141139
query_text: str
142-
top_k: PositiveInt = 5
143140
query_vector: Optional[list[float]] = None
141+
top_k: PositiveInt = 5
144142

145143

146144
class HybridCypherSearchModel(HybridSearchModel):

0 commit comments

Comments
 (0)