Skip to content

Commit 86f88bf

Browse files
authored
Fix imports when some optional dependencies are not installed (#168)
* Fix sentence-transformer embedding import
* Fix import when openAI is not installed
* Update changelog
* Fix for mypy
* ruff
1 parent cfeda57 commit 86f88bf

File tree

7 files changed

+106
-55
lines changed

7 files changed

+106
-55
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22

33
## Next
44

5+
### Fixed
6+
- Fix a bug where `openai` Python client and `numpy` were required to import any embedder or LLM.
7+
58
## 1.0.0a1
69

710
## 1.0.0a0

src/neo4j_graphrag/embeddings/openai.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515

1616
from __future__ import annotations
1717

18-
from typing import Any, Type
18+
import abc
19+
from typing import Any
1920

2021
from neo4j_graphrag.embeddings.base import Embedder
2122

@@ -25,26 +26,29 @@
2526
openai = None # type: ignore
2627

2728

28-
class OpenAIEmbeddings(Embedder):
29+
class BaseOpenAIEmbeddings(Embedder, abc.ABC):
30+
def __init__(self, model: str = "text-embedding-ada-002", **kwargs: Any) -> None:
31+
if openai is None:
32+
raise ImportError(
33+
"Could not import openai python client. "
34+
"Please install it with `pip install openai`."
35+
)
36+
self.model = model
37+
38+
39+
class OpenAIEmbeddings(BaseOpenAIEmbeddings):
2940
"""
3041
OpenAI embeddings class.
3142
This class uses the OpenAI python client to generate embeddings for text data.
3243
3344
Args:
3445
model (str): The name of the OpenAI embedding model to use. Defaults to "text-embedding-ada-002".
46+
kwargs: All other parameters will be passed to the openai.OpenAI init.
3547
"""
3648

37-
client_class: Type[openai.OpenAI] = openai.OpenAI
38-
3949
def __init__(self, model: str = "text-embedding-ada-002", **kwargs: Any) -> None:
40-
if openai is None:
41-
raise ImportError(
42-
"Could not import openai python client. "
43-
"Please install it with `pip install openai`."
44-
)
45-
46-
self.openai_model = self.client_class(**kwargs)
47-
self.model = model
50+
super().__init__(model, **kwargs)
51+
self.openai_client = openai.OpenAI(**kwargs)
4852

4953
def embed_query(self, text: str, **kwargs: Any) -> list[float]:
5054
"""
@@ -54,11 +58,13 @@ def embed_query(self, text: str, **kwargs: Any) -> list[float]:
5458
text (str): The text to generate an embedding for.
5559
**kwargs (Any): Additional arguments to pass to the OpenAI embedding generation function.
5660
"""
57-
response = self.openai_model.embeddings.create(
61+
response = self.openai_client.embeddings.create(
5862
input=text, model=self.model, **kwargs
5963
)
6064
return response.data[0].embedding
6165

6266

6367
class AzureOpenAIEmbeddings(OpenAIEmbeddings):
64-
client_class: Type[openai.OpenAI] = openai.AzureOpenAI
68+
def __init__(self, model: str = "text-embedding-ada-002", **kwargs: Any) -> None:
69+
super().__init__(model, **kwargs)
70+
self.openai_client = openai.AzureOpenAI(**kwargs)

src/neo4j_graphrag/embeddings/sentence_transformers.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,13 @@
1515

1616
from typing import Any
1717

18-
import numpy as np
19-
import torch
18+
try:
19+
import numpy as np
20+
import sentence_transformers
21+
import torch
22+
except ImportError:
23+
sentence_transformers = None # type: ignore
24+
2025

2126
from neo4j_graphrag.embeddings.base import Embedder
2227

@@ -25,15 +30,12 @@ class SentenceTransformerEmbeddings(Embedder):
2530
def __init__(
2631
self, model: str = "all-MiniLM-L6-v2", *args: Any, **kwargs: Any
2732
) -> None:
28-
try:
29-
from sentence_transformers import SentenceTransformer
30-
except ImportError as e:
33+
if sentence_transformers is None:
3134
raise ImportError(
3235
"Could not import sentence_transformers python package. "
3336
"Please install it with `pip install sentence-transformers`."
34-
) from e
35-
36-
self.model = SentenceTransformer(model, *args, **kwargs)
37+
)
38+
self.model = sentence_transformers.SentenceTransformer(model, *args, **kwargs)
3739

3840
def embed_query(self, text: str) -> Any:
3941
result = self.model.encode([text])

src/neo4j_graphrag/llm/openai_llm.py

Lines changed: 50 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414
# limitations under the License.
1515
from __future__ import annotations
1616

17-
from typing import Any, Optional, Type
17+
import abc
18+
from typing import Any, Optional
1819

1920
from ..exceptions import LLMGenerationError
2021
from .base import LLMInterface
@@ -26,32 +27,30 @@
2627
openai = None # type: ignore
2728

2829

29-
class OpenAILLM(LLMInterface):
30-
client_class: Type[openai.OpenAI] = openai.OpenAI
31-
async_client_class: Type[openai.AsyncOpenAI] = openai.AsyncOpenAI
30+
class BaseOpenAILLM(LLMInterface, abc.ABC):
31+
client: Any
32+
async_client: Any
3233

3334
def __init__(
3435
self,
3536
model_name: str,
3637
model_params: Optional[dict[str, Any]] = None,
37-
**kwargs: Any,
3838
):
3939
"""
40+
Base class for OpenAI LLM.
41+
42+
Makes sure the openai Python client is installed during init.
4043
4144
Args:
4245
model_name (str):
4346
model_params (str): Parameters like temperature that will be passed to the model when text is sent to it
44-
kwargs: All other parameters will be passed to the openai.OpenAI init.
45-
4647
"""
4748
if openai is None:
4849
raise ImportError(
4950
"Could not import openai Python client. "
5051
"Please install it with `pip install openai`."
5152
)
52-
super().__init__(model_name, model_params, **kwargs)
53-
self.client = self.client_class(**kwargs)
54-
self.async_client = self.async_client_class(**kwargs)
53+
super().__init__(model_name, model_params)
5554

5655
def get_messages(
5756
self,
@@ -76,7 +75,7 @@ def invoke(self, input: str) -> LLMResponse:
7675
"""
7776
try:
7877
response = self.client.chat.completions.create(
79-
messages=self.get_messages(input), # type: ignore
78+
messages=self.get_messages(input),
8079
model=self.model_name,
8180
**self.model_params,
8281
)
@@ -100,7 +99,7 @@ async def ainvoke(self, input: str) -> LLMResponse:
10099
"""
101100
try:
102101
response = await self.async_client.chat.completions.create(
103-
messages=self.get_messages(input), # type: ignore
102+
messages=self.get_messages(input),
104103
model=self.model_name,
105104
**self.model_params,
106105
)
@@ -110,6 +109,42 @@ async def ainvoke(self, input: str) -> LLMResponse:
110109
raise LLMGenerationError(e)
111110

112111

113-
class AzureOpenAILLM(OpenAILLM):
114-
client_class: Type[openai.OpenAI] = openai.AzureOpenAI
115-
async_client_class: Type[openai.AsyncOpenAI] = openai.AsyncAzureOpenAI
112+
class OpenAILLM(BaseOpenAILLM):
113+
def __init__(
114+
self,
115+
model_name: str,
116+
model_params: Optional[dict[str, Any]] = None,
117+
**kwargs: Any,
118+
):
119+
"""OpenAI LLM
120+
121+
Wrapper for the openai Python client LLM.
122+
123+
Args:
124+
model_name (str):
125+
model_params (str): Parameters like temperature that will be passed to the model when text is sent to it
126+
kwargs: All other parameters will be passed to the openai.OpenAI init.
127+
"""
128+
super().__init__(model_name, model_params)
129+
self.client = openai.OpenAI(**kwargs)
130+
self.async_client = openai.AsyncOpenAI(**kwargs)
131+
132+
133+
class AzureOpenAILLM(BaseOpenAILLM):
134+
def __init__(
135+
self,
136+
model_name: str,
137+
model_params: Optional[dict[str, Any]] = None,
138+
**kwargs: Any,
139+
):
140+
"""Azure OpenAI LLM. Use this class when using an OpenAI model
141+
hosted on Microsoft Azure.
142+
143+
Args:
144+
model_name (str):
145+
model_params (str): Parameters like temperature that will be passed to the model when text is sent to it
146+
kwargs: All other parameters will be passed to the openai.OpenAI init.
147+
"""
148+
super().__init__(model_name, model_params)
149+
self.client = openai.AzureOpenAI(**kwargs)
150+
self.async_client = openai.AsyncAzureOpenAI(**kwargs)

tests/unit/embeddings/test_openai_embedder.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@ def test_openai_embedder_missing_dependency() -> None:
2727
OpenAIEmbeddings()
2828

2929

30-
@patch("neo4j_graphrag.embeddings.openai.OpenAIEmbeddings.client_class")
30+
@patch("neo4j_graphrag.embeddings.openai.openai")
3131
def test_openai_embedder_happy_path(mock_openai: Mock) -> None:
32-
mock_openai.return_value.embeddings.create.return_value = MagicMock(
32+
mock_openai.OpenAI.return_value.embeddings.create.return_value = MagicMock(
3333
data=[MagicMock(embedding=[1.0, 2.0])],
3434
)
3535
embedder = OpenAIEmbeddings(api_key="my key")
@@ -44,9 +44,9 @@ def test_azure_openai_embedder_missing_dependency() -> None:
4444
AzureOpenAIEmbeddings()
4545

4646

47-
@patch("neo4j_graphrag.embeddings.openai.AzureOpenAIEmbeddings.client_class")
47+
@patch("neo4j_graphrag.embeddings.openai.openai")
4848
def test_azure_openai_embedder_happy_path(mock_openai: Mock) -> None:
49-
mock_openai.return_value.embeddings.create.return_value = MagicMock(
49+
mock_openai.AzureOpenAI.return_value.embeddings.create.return_value = MagicMock(
5050
data=[MagicMock(embedding=[1.0, 2.0])],
5151
)
5252
embedder = AzureOpenAIEmbeddings(

tests/unit/embeddings/test_sentence_transformers.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,34 +8,37 @@
88
)
99

1010

11-
@patch("sentence_transformers.SentenceTransformer")
11+
@patch("neo4j_graphrag.embeddings.sentence_transformers.sentence_transformers")
1212
def test_initialization(MockSentenceTransformer: MagicMock) -> None:
1313
instance = SentenceTransformerEmbeddings()
14-
MockSentenceTransformer.assert_called_with("all-MiniLM-L6-v2")
14+
MockSentenceTransformer.SentenceTransformer.assert_called_with("all-MiniLM-L6-v2")
1515
assert isinstance(instance, Embedder)
1616

1717

18-
@patch("sentence_transformers.SentenceTransformer")
18+
@patch("neo4j_graphrag.embeddings.sentence_transformers.sentence_transformers")
1919
def test_initialization_with_custom_model(MockSentenceTransformer: MagicMock) -> None:
2020
custom_model = "distilbert-base-nli-stsb-mean-tokens"
2121
SentenceTransformerEmbeddings(model=custom_model)
22-
MockSentenceTransformer.assert_called_with(custom_model)
22+
MockSentenceTransformer.SentenceTransformer.assert_called_with(custom_model)
2323

2424

25-
@patch("sentence_transformers.SentenceTransformer")
25+
@patch("neo4j_graphrag.embeddings.sentence_transformers.sentence_transformers")
2626
def test_embed_query(MockSentenceTransformer: MagicMock) -> None:
27-
mock_model = MockSentenceTransformer.return_value
27+
mock_model = MockSentenceTransformer.SentenceTransformer.return_value
2828
mock_model.encode.return_value = np.array([[0.1, 0.2, 0.3]])
2929

3030
instance = SentenceTransformerEmbeddings()
3131
result = instance.embed_query("test query")
3232

3333
mock_model.encode.assert_called_with(["test query"])
34-
assert result == [0.1, 0.2, 0.3]
3534
assert isinstance(result, list)
35+
assert result == [0.1, 0.2, 0.3]
3636

3737

38-
@patch("sentence_transformers.SentenceTransformer", side_effect=ImportError)
39-
def test_import_error(MockSentenceTransformer: MagicMock) -> None:
38+
@patch(
39+
"neo4j_graphrag.embeddings.sentence_transformers.sentence_transformers",
40+
None,
41+
)
42+
def test_import_error() -> None:
4043
with pytest.raises(ImportError):
4144
SentenceTransformerEmbeddings()

tests/unit/llm/test_openai_llm.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@ def test_openai_llm_missing_dependency() -> None:
2525
OpenAILLM(model_name="gpt-4o")
2626

2727

28-
@patch("neo4j_graphrag.llm.openai_llm.OpenAILLM.client_class")
28+
@patch("neo4j_graphrag.llm.openai_llm.openai")
2929
def test_openai_llm_happy_path(mock_openai: Mock) -> None:
30-
mock_openai.return_value.chat.completions.create.return_value = MagicMock(
30+
mock_openai.OpenAI.return_value.chat.completions.create.return_value = MagicMock(
3131
choices=[MagicMock(message=MagicMock(content="openai chat response"))],
3232
)
3333
llm = OpenAILLM(api_key="my key", model_name="gpt")
@@ -42,10 +42,12 @@ def test_azure_openai_llm_missing_dependency() -> None:
4242
AzureOpenAILLM(model_name="gpt-4o")
4343

4444

45-
@patch("neo4j_graphrag.llm.openai_llm.AzureOpenAILLM.client_class")
45+
@patch("neo4j_graphrag.llm.openai_llm.openai")
4646
def test_azure_openai_llm_happy_path(mock_openai: Mock) -> None:
47-
mock_openai.return_value.chat.completions.create.return_value = MagicMock(
48-
choices=[MagicMock(message=MagicMock(content="openai chat response"))],
47+
mock_openai.AzureOpenAI.return_value.chat.completions.create.return_value = (
48+
MagicMock(
49+
choices=[MagicMock(message=MagicMock(content="openai chat response"))],
50+
)
4951
)
5052
llm = AzureOpenAILLM(
5153
model_name="gpt",

0 commit comments

Comments (0)