Commit 89411ca

Support for Cohere embeddings and LLM (#147)

* Support Cohere embeddings and LLM
* Fix tests
* Ruff
* Fix dependencies
* Fix mock
* Add test for cohere LLM failure
* Rename file for consistency
* Ruff
* Mypy
* Fix tests
* Documentation & CHANGELOG
* Recreate lock file after merge

1 parent 68a49d7 · commit 89411ca

File tree: 9 files changed, +770 −248 lines changed

CHANGELOG.md
Lines changed: 8 additions & 4 deletions
@@ -4,15 +4,16 @@
 
 ### Added
 - Added AzureOpenAILLM and AzureOpenAIEmbeddings to support Azure served OpenAI models
-- Add `template` validation in `PromptTemplate` class upon construction.
+- Added `template` validation in `PromptTemplate` class upon construction.
 - `custom_prompt` arg is now converted to `Text2CypherTemplate` class within the `Text2CypherRetriever.get_search_results` method.
 - `Text2CypherTemplate` and `RAGTemplate` prompt templates now require `query_text` arg and will error if it is not present. Previous `query_text` aliases may be used, but will warn of deprecation.
-- Fix bug in `Text2CypherRetriever` using `custom_prompt` arg where the `search` method would not inject the `query_text` content.
-- Add feature to include kwargs in `Text2CypherRetriever.search()` that will be injected into a custom prompt, if provided.
-- Add validation to `custom_prompt` parameter of `Text2CypherRetriever` to ensure that `query_text` placeholder exists in prompt.
+- Fixed bug in `Text2CypherRetriever` using `custom_prompt` arg where the `search` method would not inject the `query_text` content.
+- Added feature to include kwargs in `Text2CypherRetriever.search()` that will be injected into a custom prompt, if provided.
+- Added validation to `custom_prompt` parameter of `Text2CypherRetriever` to ensure that `query_text` placeholder exists in prompt.
 - Introduced a fixed size text splitter component for splitting text into specified fixed size chunks with overlap. Updated examples and tests to utilize this new component.
 - Introduced Vertex AI LLM class for integrating Vertex AI models.
 - Added unit tests for the Vertex AI LLM class.
+- Added support for Cohere LLM and embeddings - added optional dependency to `cohere`.
 
 ### Fixed
 - Resolved import issue with the Vertex AI Embeddings class.
@@ -31,6 +32,9 @@
 - Updated documentation to include OpenAI and Vertex AI embeddings classes.
 - Added google-cloud-aiplatform as an optional dependency for Vertex AI embeddings.
 
+### Fixed
+- Make `pygraphviz` an optional dependency - it is now only required when calling `pipeline.draw`.
+
 ## 0.6.2
 
 ### Fixed
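One changelog entry above adds placeholder validation to the `custom_prompt` parameter so that `query_text` is guaranteed to exist in the template. A minimal stdlib sketch of that kind of check; the function name and error message are illustrative assumptions, not the package's actual code:

```python
from string import Formatter

def validate_custom_prompt(template: str, required: str = "query_text") -> None:
    """Raise if the required placeholder is missing (illustrative sketch)."""
    # Formatter().parse yields (literal, field_name, spec, conversion) tuples;
    # collect every named placeholder found in the template.
    fields = {name for _, name, _, _ in Formatter().parse(template) if name}
    if required not in fields:
        raise ValueError(f"custom_prompt must contain a '{{{required}}}' placeholder")

# A prompt containing the placeholder passes silently:
validate_custom_prompt("Schema: {schema}\nQuestion: {query_text}")
```

A template without `{query_text}` would raise `ValueError` at construction time instead of failing later inside `search`.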

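The fixed size text splitter mentioned in the changelog is not shown in this diff; the idea (fixed-width chunks that overlap the previous chunk) can be sketched with the standard library. Names and defaults here are assumptions, not the package's actual component:

```python
def split_fixed_size(text: str, chunk_size: int, overlap: int) -> list[str]:
    """Split text into fixed-size chunks, each sharing `overlap` chars with the previous."""
    if overlap >= chunk_size:
        raise ValueError("overlap must be smaller than chunk_size")
    step = chunk_size - overlap  # how far the window advances each time
    return [text[i : i + chunk_size] for i in range(0, len(text), step)]

chunks = split_fixed_size("abcdefghijklmno", chunk_size=5, overlap=2)
# → ['abcde', 'defgh', 'ghijk', 'jklmn', 'mno']
```

Each chunk starts 3 characters after the previous one, so consecutive chunks share a 2-character overlap; the last chunk may be shorter than `chunk_size`.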
docs/source/api.rst
Lines changed: 19 additions & 4 deletions
@@ -162,6 +162,12 @@ VertexAIEmbeddings
 .. autoclass:: neo4j_graphrag.embeddings.vertexai.VertexAIEmbeddings
     :members:
 
+CohereEmbeddings
+================
+
+.. autoclass:: neo4j_graphrag.embeddings.cohere.CohereEmbeddings
+    :members:
+
 **********
 Generation
 **********
@@ -174,17 +180,26 @@ LLMInterface
 
 
 OpenAILLM
-=========
+---------
 
-.. autoclass:: neo4j_graphrag.llm.openai.OpenAILLM
+.. autoclass:: neo4j_graphrag.llm.openai_llm.OpenAILLM
     :members:
+    :undoc-members: get_messages, client_class, async_client_class
+
 
 VertexAILLM
-===========
+-----------
 
-.. autoclass:: neo4j_graphrag.llm.vertexai.VertexAILLM
+.. autoclass:: neo4j_graphrag.llm.vertexai_llm.VertexAILLM
     :members:
 
+CohereLLM
+---------
+
+.. autoclass:: neo4j_graphrag.llm.cohere_llm.CohereLLM
+    :members:
+
+
 PromptTemplate
 ==============
 

docs/source/user_guide_rag.rst
Lines changed: 52 additions & 11 deletions
@@ -23,9 +23,9 @@ In practice, it's done with only a few lines of code:
 
     from neo4j import GraphDatabase
    from neo4j_graphrag.retrievers import VectorRetriever
-    from neo4j_graphrag.llm.openai import OpenAILLM
+    from neo4j_graphrag.llm import OpenAILLM
    from neo4j_graphrag.generation import GraphRAG
-    from neo4j_graphrag.embeddings.openai import OpenAIEmbeddings
+    from neo4j_graphrag.embeddings import OpenAIEmbeddings
 
     # 1. Neo4j driver
    URI = "neo4j://localhost:7687"
@@ -56,6 +56,12 @@ In practice, it's done with only a few lines of code:
     print(response.answer)
 
 
+.. note::
+
+    In order to run this code, the `openai` Python package needs to be installed:
+    `pip install openai`
+
+
 The following sections provide more details about how to customize this code.
 
 ******************************
@@ -70,6 +76,8 @@ Using Another LLM Model
 If OpenAI cannot be used directly, there are a few available alternatives:
 
 - Use Azure OpenAI.
+- Use Google VertexAI.
+- Use Cohere.
 - Use a local Ollama model.
 - Implement a custom interface.
 - Utilize any LangChain chat model.
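All of these alternatives are interchangeable because the pipeline depends only on a shared LLM contract (the `LLMInterface` documented in api.rst). A simplified, stdlib-only sketch of that pluggable design; the class and method shapes are reduced assumptions, not the package's exact API:

```python
from abc import ABC, abstractmethod

class LLMInterface(ABC):
    """Minimal contract every LLM backend implements (sketch)."""

    @abstractmethod
    def invoke(self, input: str) -> str: ...

class EchoLLM(LLMInterface):
    """Toy backend standing in for OpenAI, VertexAI, Cohere, or Ollama."""

    def invoke(self, input: str) -> str:
        return f"echo: {input}"

def answer(llm: LLMInterface, question: str) -> str:
    # The RAG pipeline only calls the interface, so any backend is accepted.
    return llm.invoke(question)

print(answer(EchoLLM(), "say something"))
```

Swapping providers then means constructing a different `LLMInterface` subclass and passing it to the pipeline; no other code changes.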
@@ -83,18 +91,46 @@ It is possible to use Azure OpenAI switching to the `AzureOpenAILLM` class:
 
 .. code:: python
 
-    from neo4j_graphrag.llm.openai import AzureOpenAILLM
+    from neo4j_graphrag.llm import AzureOpenAILLM
    llm = AzureOpenAILLM(
        model_name="gpt-4o",
        azure_endpoint="https://example-endpoint.openai.azure.com/",  # update with your endpoint
        api_version="2024-06-01",  # update appropriate version
-        api_key="ba3a46d86259405385f73f08078f588b",  # api_key is optional and can also be set with OPENAI_API_KEY env var
+        api_key="...",  # api_key is optional and can also be set with OPENAI_API_KEY env var
    )
    llm.invoke("say something")
 
 Check the OpenAI Python client [documentation](https://github.com/openai/openai-python?tab=readme-ov-file#microsoft-azure-openai)
 to learn more about the configuration.
 
+.. note::
+
+    In order to run this code, the `openai` Python package needs to be installed:
+    `pip install openai`
+
+
+Using VertexAI LLM
+------------------
+
+To use VertexAI, instantiate the `VertexAILLM` class:
+
+.. code:: python
+
+    from neo4j_graphrag.llm import VertexAILLM
+    from vertexai.generative_models import GenerationConfig
+
+    generation_config = GenerationConfig(temperature=0.0)
+    llm = VertexAILLM(
+        model_name="gemini-1.5-flash-001", generation_config=generation_config
+    )
+    llm.invoke("say something")
+
+
+.. note::
+
+    In order to run this code, the `google-cloud-aiplatform` Python package needs to be installed:
+    `pip install google-cloud-aiplatform`
+
 
 Using a Local Model via Ollama
 -------------------------------
@@ -105,7 +141,7 @@ it can be queried using the following:
 
 .. code:: python
 
-    from neo4j_graphrag.llm.openai import OpenAILLM
+    from neo4j_graphrag.llm import OpenAILLM
    llm = OpenAILLM(api_key="ollama", base_url="http://127.0.0.1:11434/v1", model_name="orca-mini")
    llm.invoke("say something")
 
@@ -300,13 +336,18 @@ into a vector is required. Therefore, the retriever requires knowledge of an emb
 Embedders
 -----------------------------
 
-Currently, this package supports two embedders: `OpenAIEmbeddings` and `SentenceTransformerEmbeddings`.
+Currently, this package supports several embedders:
+
+- `OpenAIEmbeddings`
+- `AzureOpenAIEmbeddings`
+- `VertexAIEmbeddings`
+- `CohereEmbeddings`
+- `SentenceTransformerEmbeddings`
 
 The `OpenAIEmbedder` was illustrated previously. Here is how to use the `SentenceTransformerEmbeddings`:
 
 .. code:: python
 
-    from neo4j_graphrag.embeddings.sentence_transformers import SentenceTransformerEmbeddings
+    from neo4j_graphrag.embeddings import SentenceTransformerEmbeddings
 
     embedder = SentenceTransformerEmbeddings(model="all-MiniLM-L6-v2")  # Note: this is the default model
 
@@ -330,10 +371,10 @@ the following implementation of an embedder that wraps the `OllamaEmbedding` mod
         return embedding[0]
 
     ollama_embedding = OllamaEmbedding(
-        model_name="llama3",
-        base_url="http://localhost:11434",
-        ollama_additional_kwargs={"mirostat": 0},
-    )
+    model_name="llama3",
+    base_url="http://localhost:11434",
+    ollama_additional_kwargs={"mirostat": 0},
+)
    embedder = OllamaEmbedder(ollama_embedding)
    vector = embedder.embed_query("some text")

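As the Ollama wrapper above suggests, an embedder only needs to expose an `embed_query` method returning a list of floats. A self-contained toy sketch of that shape, using deterministic fake vectors instead of a real model (purely illustrative, not part of the package):

```python
import hashlib

class ToyEmbedder:
    """Implements the embed_query shape with deterministic fake vectors."""

    def __init__(self, dimensions: int = 8) -> None:
        self.dimensions = dimensions

    def embed_query(self, text: str) -> list[float]:
        # Hash the text and map the first `dimensions` bytes to floats in [0, 1).
        digest = hashlib.sha256(text.encode("utf-8")).digest()
        return [b / 256 for b in digest[: self.dimensions]]

vector = ToyEmbedder().embed_query("some text")
```

Because the retriever only calls `embed_query`, any object with this method, whether it wraps Ollama, Cohere, or a local model, can be passed where an embedder is expected.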