1
- from unittest .mock import MagicMock , patch
1
+ from unittest .mock import MagicMock , patch , Mock
2
2
3
3
import numpy as np
4
4
import pytest
5
5
from neo4j_graphrag .embeddings .base import Embedder
6
6
from neo4j_graphrag .embeddings .sentence_transformers import (
7
7
SentenceTransformerEmbeddings ,
8
8
)
9
+ import torch
9
10
10
11
11
- @patch ("neo4j_graphrag.embeddings.sentence_transformers.sentence_transformers" )
12
- def test_initialization (MockSentenceTransformer : MagicMock ) -> None :
12
+ def get_mock_sentence_transformers () -> MagicMock :
13
+ mock = MagicMock ()
14
+ # I know, I know... ¯\_(ツ)_/¯
15
+ # This is to cover the if type checks in the embed_query method
16
+ mock .Tensor = torch .Tensor
17
+ mock .ndarray = np .ndarray
18
+ return mock
19
+
20
+
21
+ @patch ("builtins.__import__" )
22
+ def test_initialization (mock_import : Mock ) -> None :
23
+ MockSentenceTransformer = get_mock_sentence_transformers ()
24
+ mock_import .return_value = MockSentenceTransformer
13
25
instance = SentenceTransformerEmbeddings ()
14
26
MockSentenceTransformer .SentenceTransformer .assert_called_with ("all-MiniLM-L6-v2" )
15
27
assert isinstance (instance , Embedder )
16
28
17
29
18
- @patch ("neo4j_graphrag.embeddings.sentence_transformers.sentence_transformers" )
19
- def test_initialization_with_custom_model (MockSentenceTransformer : MagicMock ) -> None :
30
+ @patch ("builtins.__import__" )
31
+ def test_initialization_with_custom_model (mock_import : Mock ) -> None :
32
+ MockSentenceTransformer = get_mock_sentence_transformers ()
33
+ mock_import .return_value = MockSentenceTransformer
20
34
custom_model = "distilbert-base-nli-stsb-mean-tokens"
21
35
SentenceTransformerEmbeddings (model = custom_model )
22
36
MockSentenceTransformer .SentenceTransformer .assert_called_with (custom_model )
23
37
24
38
25
- @patch ("neo4j_graphrag.embeddings.sentence_transformers.sentence_transformers" )
26
- def test_embed_query (MockSentenceTransformer : MagicMock ) -> None :
39
+ @patch ("builtins.__import__" )
40
+ def test_embed_query (mock_import : Mock ) -> None :
41
+ MockSentenceTransformer = get_mock_sentence_transformers ()
42
+ mock_import .return_value = MockSentenceTransformer
27
43
mock_model = MockSentenceTransformer .SentenceTransformer .return_value
28
44
mock_model .encode .return_value = np .array ([[0.1 , 0.2 , 0.3 ]])
29
45
@@ -35,10 +51,7 @@ def test_embed_query(MockSentenceTransformer: MagicMock) -> None:
35
51
assert result == [0.1 , 0.2 , 0.3 ]
36
52
37
53
38
- @patch (
39
- "neo4j_graphrag.embeddings.sentence_transformers.sentence_transformers" ,
40
- None ,
41
- )
42
- def test_import_error () -> None :
54
+ @patch ("builtins.__import__" , side_effect = ImportError )
55
+ def test_import_error (mock_import : Mock ) -> None :
43
56
with pytest .raises (ImportError ):
44
57
SentenceTransformerEmbeddings ()
0 commit comments