Skip to content

Commit cede84d

Browse files
committed
Improve import system for sentence_transformers
1 parent 606d78e commit cede84d

File tree

2 files changed

+34
-23
lines changed

2 files changed

+34
-23
lines changed

src/neo4j_graphrag/embeddings/sentence_transformers.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,34 +15,32 @@
1515

1616
from typing import Any
1717

18-
try:
19-
import numpy as np
20-
import sentence_transformers
21-
import torch
22-
except ImportError:
23-
sentence_transformers = None # type: ignore
24-
25-
2618
from neo4j_graphrag.embeddings.base import Embedder
2719

2820

2921
class SentenceTransformerEmbeddings(Embedder):
3022
def __init__(
3123
self, model: str = "all-MiniLM-L6-v2", *args: Any, **kwargs: Any
3224
) -> None:
33-
if sentence_transformers is None:
25+
try:
26+
import numpy as np
27+
import sentence_transformers
28+
import torch
29+
except ImportError:
3430
raise ImportError(
3531
"Could not import sentence_transformers python package. "
3632
"Please install it with `pip install sentence-transformers`."
3733
)
34+
self.torch = torch
35+
self.np = np
3836
self.model = sentence_transformers.SentenceTransformer(model, *args, **kwargs)
3937

4038
def embed_query(self, text: str) -> Any:
4139
result = self.model.encode([text])
42-
if isinstance(result, torch.Tensor) or isinstance(result, np.ndarray):
40+
if isinstance(result, self.torch.Tensor) or isinstance(result, self.np.ndarray):
4341
return result.flatten().tolist()
4442
elif isinstance(result, list) and all(
45-
isinstance(x, torch.Tensor) for x in result
43+
isinstance(x, self.torch.Tensor) for x in result
4644
):
4745
return [item for tensor in result for item in tensor.flatten().tolist()]
4846
else:
Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,45 @@
1-
from unittest.mock import MagicMock, patch
1+
from unittest.mock import MagicMock, patch, Mock
22

33
import numpy as np
44
import pytest
55
from neo4j_graphrag.embeddings.base import Embedder
66
from neo4j_graphrag.embeddings.sentence_transformers import (
77
SentenceTransformerEmbeddings,
88
)
9+
import torch
910

1011

11-
@patch("neo4j_graphrag.embeddings.sentence_transformers.sentence_transformers")
12-
def test_initialization(MockSentenceTransformer: MagicMock) -> None:
12+
def get_mock_sentence_transformers() -> MagicMock:
13+
mock = MagicMock()
14+
# I know, I know... ¯\_(ツ)_/¯
15+
# This is to cover the if type checks in the embed_query method
16+
mock.Tensor = torch.Tensor
17+
mock.ndarray = np.ndarray
18+
return mock
19+
20+
21+
@patch("builtins.__import__")
22+
def test_initialization(mock_import: Mock) -> None:
23+
MockSentenceTransformer = get_mock_sentence_transformers()
24+
mock_import.return_value = MockSentenceTransformer
1325
instance = SentenceTransformerEmbeddings()
1426
MockSentenceTransformer.SentenceTransformer.assert_called_with("all-MiniLM-L6-v2")
1527
assert isinstance(instance, Embedder)
1628

1729

18-
@patch("neo4j_graphrag.embeddings.sentence_transformers.sentence_transformers")
19-
def test_initialization_with_custom_model(MockSentenceTransformer: MagicMock) -> None:
30+
@patch("builtins.__import__")
31+
def test_initialization_with_custom_model(mock_import: Mock) -> None:
32+
MockSentenceTransformer = get_mock_sentence_transformers()
33+
mock_import.return_value = MockSentenceTransformer
2034
custom_model = "distilbert-base-nli-stsb-mean-tokens"
2135
SentenceTransformerEmbeddings(model=custom_model)
2236
MockSentenceTransformer.SentenceTransformer.assert_called_with(custom_model)
2337

2438

25-
@patch("neo4j_graphrag.embeddings.sentence_transformers.sentence_transformers")
26-
def test_embed_query(MockSentenceTransformer: MagicMock) -> None:
39+
@patch("builtins.__import__")
40+
def test_embed_query(mock_import: Mock) -> None:
41+
MockSentenceTransformer = get_mock_sentence_transformers()
42+
mock_import.return_value = MockSentenceTransformer
2743
mock_model = MockSentenceTransformer.SentenceTransformer.return_value
2844
mock_model.encode.return_value = np.array([[0.1, 0.2, 0.3]])
2945

@@ -35,10 +51,7 @@ def test_embed_query(MockSentenceTransformer: MagicMock) -> None:
3551
assert result == [0.1, 0.2, 0.3]
3652

3753

38-
@patch(
39-
"neo4j_graphrag.embeddings.sentence_transformers.sentence_transformers",
40-
None,
41-
)
42-
def test_import_error() -> None:
54+
@patch("builtins.__import__", side_effect=ImportError)
55+
def test_import_error(mock_import: Mock) -> None:
4356
with pytest.raises(ImportError):
4457
SentenceTransformerEmbeddings()

0 commit comments

Comments
 (0)