Skip to content

Commit bc8540e

Browse files
authored
Improve import system (#175)
* Improve import system for cohere * Improve import system for openai * Improve import system for anthropic * Improve import system for sentence_transformers * Fix async tests to support older Python versions * Add tox for multiversion testing * Reorganise dependencies * Add qdrant to extras * Move llama-index to experimental group * Update lock file
1 parent 57529d4 commit bc8540e

17 files changed

+861
-1065
lines changed

.github/workflows/pr-e2e-tests.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ jobs:
7777
key: ${{ runner.os }}-venv-${{ matrix.python-version }}-${{ hashFiles('**/poetry.lock') }}
7878
- name: Install dependencies
7979
if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
80-
run: poetry install --no-interaction --no-cache --with dev
80+
run: poetry install --no-interaction --no-cache --with dev --all-extras
8181
- name: Clear Poetry cache
8282
run: poetry cache clear --all .
8383
- name: Show disk usage after Poetry installation

.github/workflows/pr.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ jobs:
3232
key: venv-${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/poetry.lock') }}
3333
- name: Install dependencies
3434
if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
35-
run: poetry install --no-interaction
35+
run: poetry install --no-interaction --all-extras
3636
- name: Check format and linting
3737
run: |
3838
poetry run ruff check .

.github/workflows/scheduled-e2e-tests.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ jobs:
8585
key: ${{ runner.os }}-venv-${{ matrix.python-version }}-${{ hashFiles('**/poetry.lock') }}
8686
- name: Install dependencies
8787
if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
88-
run: poetry install --no-interaction --no-cache --with dev
88+
run: poetry install --no-interaction --no-cache --with dev --all-extras
8989
- name: Clear Poetry cache
9090
run: poetry cache clear --all .
9191
- name: Show disk usage after Poetry installation

docs/source/index.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ Install dependencies
194194

195195
.. code:: bash
196196
197-
poetry install
197+
poetry install --all-extras
198198
199199
***************
200200
Getting started

poetry.lock

Lines changed: 665 additions & 909 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 27 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -31,62 +31,49 @@ from = "src"
3131
python = "^3.9.0"
3232
neo4j = "^5.17.0"
3333
pydantic = "^2.6.3"
34-
urllib3 = "<2"
35-
weaviate-client = {version = "^4.6.1", optional = true}
36-
pinecone-client = {version = "^4.1.0", optional = true}
37-
types-mock = "^5.1.0.20240425"
38-
eval-type-backport = "^0.2.0"
39-
pypdf = "^4.3.1"
40-
fsspec = "^2024.9.0"
34+
fsspec = {version = "^2024.9.0", optional = true}
35+
langchain-text-splitters = {version = "^0.3.0", optional = true }
36+
pypdf = {version = "^4.3.1", optional = true}
4137
pygraphviz = [
4238
{version = "^1.13.0", python = ">=3.10,<4.0.0", optional = true},
4339
{version = "^1.0.0", python = "<3.10", optional = true}
4440
]
45-
google-cloud-aiplatform = {version = "^1.66.0", optional = true}
41+
weaviate-client = {version = "^4.6.1", optional = true }
42+
pinecone-client = {version = "^4.1.0", optional = true }
43+
google-cloud-aiplatform = {version = "^1.66.0", optional = true }
4644
cohere = {version = "^5.9.0", optional = true}
47-
anthropic = { version = "^0.34.2", optional = true}
4845
mistralai = {version = "^1.0.3", optional = true}
4946
qdrant-client = {version = "^1.11.3", optional = true}
47+
llama-index = {version = "^0.10.55", optional = true }
48+
openai = {version = "^1.51.1", optional = true }
49+
anthropic = { version = "^0.36.0", optional = true}
50+
sentence-transformers = {version = "^3.0.0", optional = true }
5051

5152
[tool.poetry.group.dev.dependencies]
52-
pylint = "^3.1.0"
53+
urllib3 = "<2"
54+
ruff = "^0.3.0"
5355
mypy = "^1.10.0"
5456
pytest = "^8.0.2"
55-
pytest-mock = "^3.12.0"
56-
pre-commit = { version = "^3.6.2", python = "^3.9" }
5757
coverage = "^7.4.3"
58-
ruff = "^0.3.0"
59-
langchain-text-splitters = "^0.3.0"
60-
weaviate-client = "^4.6.1"
61-
sentence-transformers = "^3.0.0"
62-
pinecone-client = "^4.1.0"
63-
requests = "^2.32.0"
64-
sphinx = { version = "^7.2.6", python = "^3.9" }
65-
tox = "^4.15.1"
66-
numpy = [
67-
{version = "^1.24.0", python = "<3.12"},
68-
{version = "^1.26.0", python = ">=3.12"}
69-
]
70-
scipy = [
71-
{version = "^1", python = "<3.12"},
72-
{version = "^1.7.0", python = ">=3.12"}
73-
]
74-
llama-index = "^0.10.55"
7558
pytest-asyncio = "^0.23.8"
76-
pygraphviz = [
77-
{version = "^1.13.0", python = ">=3.10,<4.0.0"},
78-
{version = "^1.0.0", python = "<3.10"}
79-
]
80-
google-cloud-aiplatform = {version = "^1.66.0"}
81-
cohere = {version = "^5.9.0"}
82-
anthropic = { version = "^0.34.2"}
83-
mistralai = {version = "^1.0.3"}
84-
qdrant-client = {version = "^1.11.3"}
85-
langchain-openai = "^0.2.2" # needed in the examples
59+
pre-commit = { version = "^3.6.2", python = "^3.9" }
60+
sphinx = { version = "^7.2.6", python = "^3.9" }
61+
langchain-openai = {version = "^0.2.2", optional = true }
62+
langchain-huggingface = {version = "^0.1.0", optional = true }
8663

8764
[tool.poetry.extras]
88-
external_clients = ["weaviate-client", "pinecone-client", "google-cloud-aiplatform", "cohere", "anthropic", "mistralai", "qdrant-client"]
65+
weaviate = ["weaviate-client"]
66+
pinecone = ["pinecone-client"]
67+
google = ["google-cloud-aiplatform"]
68+
cohere = ["cohere"]
69+
anthropic = ["anthropic"]
70+
openai = ["openai"]
71+
mistralai = ["mistralai"]
72+
qdrant = ["qdrant-client"]
8973
kg_creation_tools = ["pygraphviz"]
74+
sentence-transformers = ["sentence-transformers"]
75+
experimental = ["pypdf", "fsspec", "langchain-text-splitters", "pygraphviz", "llama-index"]
76+
examples = ["langchain-openai", "langchain-huggingface"]
9077

9178
[build-system]
9279
requires = ["poetry-core>=1.0.0"]
@@ -101,9 +88,6 @@ filterwarnings = [
10188
[tool.coverage.paths]
10289
source = ["src"]
10390

104-
[tool.pylint."MESSAGES CONTROL"]
105-
disable="C0114,C0115"
106-
10791
[tool.mypy]
10892
strict = true
10993
ignore_missing_imports = true

src/neo4j_graphrag/embeddings/openai.py

Lines changed: 12 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -15,28 +15,11 @@
1515

1616
from __future__ import annotations
1717

18-
import abc
1918
from typing import Any
20-
2119
from neo4j_graphrag.embeddings.base import Embedder
2220

23-
try:
24-
import openai
25-
except ImportError:
26-
openai = None # type: ignore
27-
28-
29-
class BaseOpenAIEmbeddings(Embedder, abc.ABC):
30-
def __init__(self, model: str = "text-embedding-ada-002", **kwargs: Any) -> None:
31-
if openai is None:
32-
raise ImportError(
33-
"Could not import openai python client. "
34-
"Please install it with `pip install openai`."
35-
)
36-
self.model = model
37-
3821

39-
class OpenAIEmbeddings(BaseOpenAIEmbeddings):
22+
class OpenAIEmbeddings(Embedder):
4023
"""
4124
OpenAI embeddings class.
4225
This class uses the OpenAI python client to generate embeddings for text data.
@@ -47,8 +30,16 @@ class OpenAIEmbeddings(BaseOpenAIEmbeddings):
4730
"""
4831

4932
def __init__(self, model: str = "text-embedding-ada-002", **kwargs: Any) -> None:
50-
super().__init__(model, **kwargs)
51-
self.openai_client = openai.OpenAI(**kwargs)
33+
try:
34+
import openai
35+
except ImportError:
36+
raise ImportError(
37+
"Could not import openai python client. "
38+
"Please install it with `pip install openai`."
39+
)
40+
self.openai = openai
41+
self.model = model
42+
self.openai_client = self.openai.OpenAI(**kwargs)
5243

5344
def embed_query(self, text: str, **kwargs: Any) -> list[float]:
5445
"""
@@ -67,4 +58,4 @@ def embed_query(self, text: str, **kwargs: Any) -> list[float]:
6758
class AzureOpenAIEmbeddings(OpenAIEmbeddings):
6859
def __init__(self, model: str = "text-embedding-ada-002", **kwargs: Any) -> None:
6960
super().__init__(model, **kwargs)
70-
self.openai_client = openai.AzureOpenAI(**kwargs)
61+
self.openai_client = self.openai.AzureOpenAI(**kwargs)

src/neo4j_graphrag/embeddings/sentence_transformers.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,34 +15,32 @@
1515

1616
from typing import Any
1717

18-
try:
19-
import numpy as np
20-
import sentence_transformers
21-
import torch
22-
except ImportError:
23-
sentence_transformers = None # type: ignore
24-
25-
2618
from neo4j_graphrag.embeddings.base import Embedder
2719

2820

2921
class SentenceTransformerEmbeddings(Embedder):
3022
def __init__(
3123
self, model: str = "all-MiniLM-L6-v2", *args: Any, **kwargs: Any
3224
) -> None:
33-
if sentence_transformers is None:
25+
try:
26+
import numpy as np
27+
import sentence_transformers
28+
import torch
29+
except ImportError:
3430
raise ImportError(
3531
"Could not import sentence_transformers python package. "
3632
"Please install it with `pip install sentence-transformers`."
3733
)
34+
self.torch = torch
35+
self.np = np
3836
self.model = sentence_transformers.SentenceTransformer(model, *args, **kwargs)
3937

4038
def embed_query(self, text: str) -> Any:
4139
result = self.model.encode([text])
42-
if isinstance(result, torch.Tensor) or isinstance(result, np.ndarray):
40+
if isinstance(result, self.torch.Tensor) or isinstance(result, self.np.ndarray):
4341
return result.flatten().tolist()
4442
elif isinstance(result, list) and all(
45-
isinstance(x, torch.Tensor) for x in result
43+
isinstance(x, self.torch.Tensor) for x in result
4644
):
4745
return [item for tensor in result for item in tensor.flatten().tolist()]
4846
else:

src/neo4j_graphrag/llm/anthropic_llm.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,6 @@
1919
from neo4j_graphrag.llm.base import LLMInterface
2020
from neo4j_graphrag.llm.types import LLMResponse
2121

22-
try:
23-
import anthropic
24-
from anthropic import APIError
25-
except ImportError:
26-
anthropic = None # type: ignore
27-
APIError = None # type: ignore
28-
2922

3023
class AnthropicLLM(LLMInterface):
3124
"""Interface for large language models on Anthropic
@@ -58,12 +51,15 @@ def __init__(
5851
model_params: Optional[dict[str, Any]] = None,
5952
**kwargs: Any,
6053
):
61-
if anthropic is None:
54+
try:
55+
import anthropic
56+
except ImportError:
6257
raise ImportError(
6358
"Could not import Anthropic Python client. "
6459
"Please install it with `pip install anthropic`."
6560
)
6661
super().__init__(model_name, model_params)
62+
self.anthropic = anthropic
6763
self.client = anthropic.Anthropic(**kwargs)
6864
self.async_client = anthropic.AsyncAnthropic(**kwargs)
6965

@@ -88,7 +84,7 @@ def invoke(self, input: str) -> LLMResponse:
8884
**self.model_params,
8985
)
9086
return LLMResponse(content=response.content)
91-
except APIError as e:
87+
except self.anthropic.APIError as e:
9288
raise LLMGenerationError(e)
9389

9490
async def ainvoke(self, input: str) -> LLMResponse:
@@ -112,5 +108,5 @@ async def ainvoke(self, input: str) -> LLMResponse:
112108
**self.model_params,
113109
)
114110
return LLMResponse(content=response.content)
115-
except APIError as e:
111+
except self.anthropic.APIError as e:
116112
raise LLMGenerationError(e)

src/neo4j_graphrag/llm/cohere_llm.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,6 @@
2020
from neo4j_graphrag.llm.base import LLMInterface
2121
from neo4j_graphrag.llm.types import LLMResponse
2222

23-
try:
24-
import cohere
25-
from cohere.core import ApiError
26-
except ImportError:
27-
cohere = None # type: ignore
28-
ApiError = Exception # type: ignore[assignment, misc]
29-
3023

3124
class CohereLLM(LLMInterface):
3225
"""Interface for large language models on the Cohere platform
@@ -55,12 +48,18 @@ def __init__(
5548
model_params: Optional[dict[str, Any]] = None,
5649
**kwargs: Any,
5750
) -> None:
58-
if cohere is None:
51+
super().__init__(model_name, model_params)
52+
try:
53+
import cohere
54+
except ImportError:
5955
raise ImportError(
6056
"Could not import cohere python client. "
6157
"Please install it with `pip install cohere`."
6258
)
63-
super().__init__(model_name, model_params)
59+
60+
self.cohere = cohere
61+
self.cohere_api_error = cohere.core.api_error.ApiError
62+
6463
self.client = cohere.Client(**kwargs)
6564
self.async_client = cohere.AsyncClient(**kwargs)
6665

@@ -78,7 +77,7 @@ def invoke(self, input: str) -> LLMResponse:
7877
message=input,
7978
model=self.model_name,
8079
)
81-
except ApiError as e:
80+
except self.cohere_api_error as e:
8281
raise LLMGenerationError(e)
8382
return LLMResponse(
8483
content=res.text,
@@ -98,7 +97,7 @@ async def ainvoke(self, input: str) -> LLMResponse:
9897
message=input,
9998
model=self.model_name,
10099
)
101-
except ApiError as e:
100+
except self.cohere_api_error as e:
102101
raise LLMGenerationError(e)
103102
return LLMResponse(
104103
content=res.text,

0 commit comments

Comments
 (0)