Skip to content

Commit 59ac9d9

Browse files
authored
Removes Pinecone and Weaviate retrievers from __init__.py (neo4j#59)
* Add try/catch in __init__.py for optional dependencies * Removed pinecone and weaviate retrievers from init * Resolve type error resulting from sentence_tranformers .encode method * Updated CHANGELOG
1 parent 83b2002 commit 59ac9d9

File tree

7 files changed

+155
-156
lines changed

7 files changed

+155
-156
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,4 @@
2020
### Fixed
2121
- Updated documentation to include new custom exceptions.
2222
- Improved the use of Pydantic for input data validation for retriever objects.
23+
- Removed Pinecone and Weaviate retrievers from __init__.py to prevent ImportError when optional dependencies are not installed.

poetry.lock

Lines changed: 137 additions & 146 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/neo4j_genai/__init__.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,18 +13,15 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
from .retrievers.external.pinecone import PineconeNeo4jRetriever
17-
from .retrievers.external.weaviate import WeaviateNeo4jRetriever
1816
from .retrievers.hybrid import HybridCypherRetriever, HybridRetriever
1917
from .retrievers.text2cypher import Text2CypherRetriever
2018
from .retrievers.vector import VectorCypherRetriever, VectorRetriever
2119

20+
2221
__all__ = [
2322
"VectorRetriever",
2423
"VectorCypherRetriever",
2524
"HybridRetriever",
2625
"HybridCypherRetriever",
2726
"Text2CypherRetriever",
28-
"WeaviateNeo4jRetriever",
29-
"PineconeNeo4jRetriever",
3027
]

src/neo4j_genai/embeddings/sentence_transformers.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from neo4j_genai.embedder import Embedder
22
from typing import Any
3+
import torch
4+
import numpy as np
35

46

57
class SentenceTransformerEmbeddings(Embedder):
@@ -17,4 +19,12 @@ def __init__(
1719
self.model = SentenceTransformer(model, *args, **kwargs)
1820

1921
def embed_query(self, text: str) -> Any:
20-
return self.model.encode([text]).flatten().tolist()
22+
result = self.model.encode([text])
23+
if isinstance(result, torch.Tensor) or isinstance(result, np.ndarray):
24+
return result.flatten().tolist()
25+
elif isinstance(result, list) and all(
26+
isinstance(x, torch.Tensor) for x in result
27+
):
28+
return [item for tensor in result for item in tensor.flatten().tolist()]
29+
else:
30+
raise ValueError("Unexpected return type from model encoding")

src/neo4j_genai/retrievers/external/pinecone/pinecone.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ class PineconeNeo4jRetriever(ExternalRetriever):
5252
.. code-block:: python
5353
5454
from neo4j import GraphDatabase
55-
from neo4j_genai import PineconeNeo4jRetriever
55+
from neo4j_genai.retrievers.external.pinecone import PineconeNeo4jRetriever
5656
from pinecone import Pinecone
5757
5858
with GraphDatabase.driver(NEO4J_URL, auth=NEO4J_AUTH) as neo4j_driver:
@@ -153,7 +153,7 @@ def _get_search_results(
153153
.. code-block:: python
154154
155155
from neo4j import GraphDatabase
156-
from neo4j_genai import PineconeNeo4jRetriever
156+
from neo4j_genai.retrievers.external.pinecone import PineconeNeo4jRetriever
157157
from pinecone import Pinecone
158158
159159
with GraphDatabase.driver(NEO4J_URL, auth=NEO4J_AUTH) as neo4j_driver:

src/neo4j_genai/retrievers/external/weaviate/weaviate.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ class WeaviateNeo4jRetriever(ExternalRetriever):
4545
.. code-block:: python
4646
4747
from neo4j import GraphDatabase
48-
from neo4j_genai import WeaviateNeo4jRetriever
48+
from neo4j_genai.retrievers.external.weaviate import WeaviateNeo4jRetriever
4949
from weaviate.connect.helpers import connect_to_local
5050
5151
with GraphDatabase.driver(NEO4J_URL, auth=NEO4J_AUTH) as neo4j_driver:
@@ -136,7 +136,7 @@ def _get_search_results(
136136
.. code-block:: python
137137
138138
import neo4j
139-
from neo4j_genai.retrievers import WeaviateNeo4jRetriever
139+
from neo4j_genai.retrievers.external.weaviate import WeaviateNeo4jRetriever
140140
141141
driver = neo4j.GraphDatabase.driver(URI, auth=AUTH)
142142

tests/e2e/pinecone_e2e/test_pinecone_e2e.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
import pytest
2121
from langchain_community.embeddings import HuggingFaceEmbeddings
22-
from neo4j_genai import PineconeNeo4jRetriever
22+
from neo4j_genai.retrievers.external.pinecone import PineconeNeo4jRetriever
2323
from neo4j_genai.embedder import Embedder
2424
from neo4j_genai.types import RetrieverResult, RetrieverResultItem
2525
from pinecone import Pinecone

0 commit comments

Comments
 (0)