Skip to content

Commit d1cef28

Browse files
authored
Added Vertex AI LLM class (#141)
* Added Vertex AI LLM class * Updated docstrings * Updated unit test workflow * Removed duplicate poetry install for pr workflow * Removed --no-root from poetry install in pr.yaml * Fixed typo * Updated docs * Updated CHANGELOG for previous PR * Updated CHANGELOG * Fixed typo
1 parent bf65ddd commit d1cef28

File tree

9 files changed

+182
-14
lines changed

9 files changed

+182
-14
lines changed

.github/workflows/pr.yaml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,7 @@ jobs:
3232
key: venv-${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/poetry.lock') }}
3333
- name: Install dependencies
3434
if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
35-
run: poetry install --no-interaction --no-root
36-
- name: Install root project
37-
run: poetry install --no-interaction
35+
run: poetry install --no-interaction --extras external_clients
3836
- name: Check format and linting
3937
run: |
4038
poetry run ruff check --select I .

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@
99
- Fix bug in `Text2CypherRetriever` using `custom_prompt` arg where the `search` method would not inject the `query_text` content.
1010
- Add feature to include kwargs in `Text2CypherRetriever.search()` that will be injected into a custom prompt, if provided.
1111
- Add validation to `custom_prompt` parameter of `Text2CypherRetriever` to ensure that `query_text` placeholder exists in prompt.
12+
- Introduced a fixed size text splitter component for splitting text into specified fixed size chunks with overlap. Updated examples and tests to utilize this new component.
13+
- Introduced Vertex AI LLM class for integrating Vertex AI models.
14+
- Added unit tests for the Vertex AI LLM class.
15+
16+
### Fixed
17+
- Resolved import issue with the Vertex AI Embeddings class.
1218

1319
### Changed
1420
- Moved the Embedder class to the neo4j_graphrag.embeddings directory for better organization alongside other custom embedders.

docs/source/api.rst

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,11 +174,16 @@ LLMInterface
174174

175175

176176
OpenAILLM
177-
---------
177+
=========
178178

179179
.. autoclass:: neo4j_graphrag.llm.OpenAILLM
180180
:members:
181181

182+
VertexAILLM
183+
===========
184+
185+
.. autoclass:: neo4j_graphrag.llm.vertexai.VertexAILLM
186+
:members:
182187

183188
PromptTemplate
184189
==============
@@ -389,4 +394,4 @@ PipelineStatusUpdateError
389394
=========================
390395

391396
.. autoclass:: neo4j_graphrag.experimental.pipeline.exceptions.PipelineStatusUpdateError
392-
:show-inheritance:
397+
:show-inheritance:

src/neo4j_graphrag/embeddings/vertexai.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15-
1615
from __future__ import annotations
1716

1817
from typing import Any
@@ -22,10 +21,8 @@
2221
try:
2322
from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel
2423
except ImportError:
25-
raise ImportError(
26-
"Could not import Vertex AI python client. "
27-
"Please install it with `pip install google-cloud-aiplatform`."
28-
)
24+
TextEmbeddingInput = None
25+
TextEmbeddingModel = None
2926

3027

3128
class VertexAIEmbeddings(Embedder):
@@ -38,6 +35,11 @@ class VertexAIEmbeddings(Embedder):
3835
"""
3936

4037
def __init__(self, model: str = "text-embedding-004") -> None:
38+
if TextEmbeddingInput is None or TextEmbeddingInput is None:
39+
raise ImportError(
40+
"Could not import Vertex AI Python client. "
41+
"Please install it with `pip install google-cloud-aiplatform`."
42+
)
4143
self.vertexai_model = TextEmbeddingModel.from_pretrained(model)
4244

4345
def embed_query(

src/neo4j_graphrag/llm/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,4 @@
1616
from .openai_llm import OpenAILLM
1717
from .types import LLMResponse
1818

19-
__all__ = ["LLMResponse", "LLMInterface", "OpenAILLM"]
19+
__all__ = ["LLMResponse", "LLMInterface", "OpenAILLM", "VertexAILLM"]

src/neo4j_graphrag/llm/base.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,13 @@
2121

2222

2323
class LLMInterface(ABC):
24-
"""Interface for large language models."""
24+
"""Interface for large language models.
25+
26+
Args:
27+
model_name (str): The name of the language model.
28+
model_params (Optional[dict], optional): Additional parameters passed to the model when text is sent to it. Defaults to None.
29+
**kwargs (Any): Arguments passed to the model when for the class is initialised. Defaults to None.
30+
"""
2531

2632
def __init__(
2733
self,

src/neo4j_graphrag/llm/openai_llm.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,7 @@ def __init__(
3737
3838
Args:
3939
model_name (str):
40-
model_params (str): Parameters like temperature and such that will be
41-
passed to the model
40+
model_params (str): Parameters like temperature that will be passed to the model when text is sent to it
4241
kwargs: All other parameters will be passed to the openai.OpenAI init.
4342
4443
"""

src/neo4j_graphrag/llm/vertexai.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
# Neo4j Sweden AB [https://neo4j.com]
2+
# #
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
# #
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
# #
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from __future__ import annotations
15+
16+
from typing import Any, Optional
17+
18+
from neo4j_graphrag.exceptions import LLMGenerationError
19+
from neo4j_graphrag.llm.base import LLMInterface
20+
from neo4j_graphrag.llm.types import LLMResponse
21+
22+
try:
23+
from vertexai.generative_models import GenerativeModel, ResponseValidationError
24+
except ImportError:
25+
GenerativeModel = None
26+
ResponseValidationError = None
27+
28+
29+
class VertexAILLM(LLMInterface):
30+
"""Interface for large language models on Vertex AI
31+
32+
Args:
33+
model_name (str, optional): Name of the LLM to use. Defaults to "gemini-1.5-flash-001".
34+
model_params (Optional[dict], optional): Additional parameters passed to the model when text is sent to it. Defaults to None.
35+
**kwargs (Any): Arguments passed to the model when for the class is initialised. Defaults to None.
36+
37+
Raises:
38+
LLMGenerationError: If there's an error generating the response from the model.
39+
40+
Example:
41+
42+
.. code-block:: python
43+
44+
from neo4j_graphrag.llm import VertexAILLM
45+
from vertexai.generative_models import GenerationConfig
46+
47+
generation_config = GenerationConfig(temperature=0.0)
48+
llm = VertexAILLM(
49+
model_name="gemini-1.5-flash-001", generation_config=generation_config
50+
)
51+
llm.invoke("Who is the mother of Paul Atreides?")
52+
"""
53+
54+
def __init__(
55+
self,
56+
model_name: str = "gemini-1.5-flash-001",
57+
model_params: Optional[dict[str, Any]] = None,
58+
**kwargs: Any,
59+
):
60+
if GenerativeModel is None or ResponseValidationError is None:
61+
raise ImportError(
62+
"Could not import Vertex AI Python client. "
63+
"Please install it with `pip install google-cloud-aiplatform`."
64+
)
65+
super().__init__(model_name, model_params)
66+
self.model = GenerativeModel(model_name=model_name, **kwargs)
67+
68+
def invoke(self, input: str) -> LLMResponse:
69+
"""Sends text to the LLM and returns a response.
70+
71+
Args:
72+
input (str): The text to send to the LLM.
73+
74+
Returns:
75+
LLMResponse: The response from the LLM.
76+
"""
77+
try:
78+
response = self.model.generate_content(input, **self.model_params)
79+
return LLMResponse(content=response.text)
80+
except ResponseValidationError as e:
81+
raise LLMGenerationError(e)
82+
83+
async def ainvoke(self, input: str) -> LLMResponse:
84+
"""Asynchronously sends text to the LLM and returns a response.
85+
86+
Args:
87+
input (str): The text to send to the LLM.
88+
89+
Returns:
90+
LLMResponse: The response from the LLM.
91+
"""
92+
try:
93+
response = await self.model.generate_content_async(
94+
input, **self.model_params
95+
)
96+
return LLMResponse(content=response.text)
97+
except ResponseValidationError as e:
98+
raise LLMGenerationError(e)

tests/unit/llm/test_vertexai_llm.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Neo4j Sweden AB [https://neo4j.com]
2+
# #
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
# #
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
# #
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from __future__ import annotations
15+
16+
from unittest.mock import AsyncMock, MagicMock, Mock, patch
17+
18+
import pytest
19+
from neo4j_graphrag.llm.vertexai import VertexAILLM
20+
21+
22+
@patch("neo4j_graphrag.llm.vertexai.GenerativeModel", None)
23+
def test_vertexai_llm_missing_dependency() -> None:
24+
with pytest.raises(ImportError):
25+
VertexAILLM(model_name="gemini-1.5-flash-001")
26+
27+
28+
@patch("neo4j_graphrag.llm.vertexai.GenerativeModel")
29+
def test_invoke_happy_path(GenerativeModelMock: MagicMock) -> None:
30+
mock_response = Mock()
31+
mock_response.text = "Return text"
32+
mock_model = GenerativeModelMock.return_value
33+
mock_model.generate_content.return_value = mock_response
34+
model_params = {"temperature": 0.5}
35+
llm = VertexAILLM("gemini-1.5-flash-001", model_params)
36+
input_text = "may thy knife chip and shatter"
37+
response = llm.invoke(input_text)
38+
assert response.content == "Return text"
39+
llm.model.generate_content.assert_called_once_with(input_text, **model_params)
40+
41+
42+
@pytest.mark.asyncio
43+
@patch("neo4j_graphrag.llm.vertexai.GenerativeModel")
44+
async def test_ainvoke_happy_path(GenerativeModelMock: MagicMock) -> None:
45+
mock_response = AsyncMock()
46+
mock_response.text = "Return text"
47+
mock_model = GenerativeModelMock.return_value
48+
mock_model.generate_content_async = AsyncMock(return_value=mock_response)
49+
model_params = {"temperature": 0.5}
50+
llm = VertexAILLM("gemini-1.5-flash-001", model_params)
51+
input_text = "may thy knife chip and shatter"
52+
response = await llm.ainvoke(input_text)
53+
assert response.content == "Return text"
54+
llm.model.generate_content_async.assert_called_once_with(input_text, **model_params)

0 commit comments

Comments
 (0)