Commit 1edec5d

Mistral support (#142)
* Add MistralAIEmbeddings and MistralAILLM
* Add examples
* Add unit tests for Mistral embedder and LLM classes
* Update docs
* Ruff and modify example
* Fixed docstring
* Address comments
* Fix tests
* Fix get_message private method
* Fixed mypy errors
* Replace | operand with explicit typing.Union
* Update CHANGELOG
* Generic errors
* Check for api_key in kwargs first before using MISTRAL_API_KEY env var
* PR comments and update docs
* Removed types-requests
* Removed redundant items in CHANGELOG
* Add EmbeddingsGenerationError
* Resolve deps issues
* Upload poetry.lock
* Modified check for failure to retrieve embeddings
* EmbeddingsGenerationError for embed_query
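
One item above, checking `api_key` in kwargs before falling back to the `MISTRAL_API_KEY` environment variable, can be sketched roughly as follows. This is an illustration only, not the code from this commit: `resolve_api_key` is a hypothetical helper name, and the real classes may handle a missing key differently.

```python
import os
from typing import Any, Optional


def resolve_api_key(**kwargs: Any) -> Optional[str]:
    """Return an explicit api_key kwarg if given, else the MISTRAL_API_KEY env var.

    Hypothetical helper for illustration; not part of neo4j_graphrag.
    """
    api_key = kwargs.get("api_key")
    if api_key is None:
        # Fall back to the environment only when no key was passed explicitly.
        api_key = os.environ.get("MISTRAL_API_KEY")
    return api_key
```

With this ordering, a key passed by the caller always wins over whatever happens to be set in the environment.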
1 parent 7cb652d commit 1edec5d

16 files changed: +1035 -469 lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -7,6 +7,7 @@
 - Added `template` validation in `PromptTemplate` class upon construction.
 - `custom_prompt` arg is now converted to `Text2CypherTemplate` class within the `Text2CypherRetriever.get_search_results` method.
 - `Text2CypherTemplate` and `RAGTemplate` prompt templates now require `query_text` arg and will error if it is not present. Previous `query_text` aliases may be used, but will warn of deprecation.
+- Examples demonstrating the use of Mistral embeddings and LLM in RAG pipelines.
 - Fixed bug in `Text2CypherRetriever` using `custom_prompt` arg where the `search` method would not inject the `query_text` content.
 - Added feature to include kwargs in `Text2CypherRetriever.search()` that will be injected into a custom prompt, if provided.
 - Added validation to `custom_prompt` parameter of `Text2CypherRetriever` to ensure that `query_text` placeholder exists in prompt.
@@ -15,6 +16,7 @@
 - Added unit tests for the Vertex AI LLM class.
 - Added support for Cohere LLM and embeddings - added optional dependency to `cohere`.
 - Added support for Anthropic LLM - added optional dependency to `anthropic`.
+- Added support for MistralAI LLM - added optional dependency to `mistralai`.

 ### Fixed
 - Resolved import issue with the Vertex AI Embeddings class.

docs/source/api.rst

Lines changed: 21 additions & 0 deletions
@@ -156,12 +156,24 @@ OpenAIEmbeddings
 .. autoclass:: neo4j_graphrag.embeddings.openai.OpenAIEmbeddings
     :members:

+AzureOpenAIEmbeddings
+=====================
+
+.. autoclass:: neo4j_graphrag.embeddings.openai.AzureOpenAIEmbeddings
+    :members:
+
 VertexAIEmbeddings
 ==================

 .. autoclass:: neo4j_graphrag.embeddings.vertexai.VertexAIEmbeddings
     :members:

+MistralAIEmbeddings
+===================
+
+.. autoclass:: neo4j_graphrag.embeddings.mistral.MistralAIEmbeddings
+    :members:
+
 CohereEmbeddings
 ================

@@ -218,6 +230,13 @@ CohereLLM
     :members:


+MistralAILLM
+------------
+
+.. autoclass:: neo4j_graphrag.llm.mistralai_llm.MistralAILLM
+    :members:
+
+
 PromptTemplate
 ==============

@@ -282,6 +301,8 @@ Errors

 * :class:`neo4j_graphrag.exceptions.RetrieverInitializationError`

+* :class:`neo4j_graphrag.exceptions.EmbeddingsGenerationError`
+
 * :class:`neo4j_graphrag.exceptions.SearchValidationError`

 * :class:`neo4j_graphrag.exceptions.FilterValidationError`
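
The `EmbeddingsGenerationError` entry added to the error list suggests that failures while retrieving embeddings surface as a dedicated exception (the commit message mentions it for `embed_query`). A minimal sketch of that pattern, using stand-in classes defined locally rather than the real `neo4j_graphrag` or `mistralai` ones; `FailingClient` and `SketchEmbedder` are hypothetical names:

```python
from typing import List


class EmbeddingsGenerationError(Exception):
    """Local stand-in for neo4j_graphrag.exceptions.EmbeddingsGenerationError."""


class FailingClient:
    """Hypothetical backend client that always fails, to exercise the error path."""

    def embed(self, text: str) -> List[float]:
        raise ConnectionError("backend unavailable")


class SketchEmbedder:
    """Hypothetical embedder that converts backend failures into one error type."""

    def __init__(self, client: FailingClient) -> None:
        self.client = client

    def embed_query(self, text: str) -> List[float]:
        try:
            return self.client.embed(text)
        except Exception as exc:
            # Re-raise as a library-level error so callers catch a single type.
            raise EmbeddingsGenerationError(
                f"Failed to retrieve embeddings: {exc}"
            ) from exc
```

Chaining with `from exc` keeps the original backend error available on `__cause__` for debugging.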

docs/source/user_guide_rag.rst

Lines changed: 10 additions & 8 deletions
@@ -393,14 +393,16 @@ into a vector is required. Therefore, the retriever requires knowledge of an emb
 Embedders
 -----------------------------

-Currently, this package supports several embedders:
-- `OpenAIEmbeddings`
-- `AzureOpenAIEmbeddings`
-- `VertexAIEmbeddings`
-- `CohereEmbeddings`
-- `SentenceTransformerEmbeddings`.
-
-The `OpenAIEmbedder` was illustrated previously. Here is how to use the `SentenceTransformerEmbeddings`:
+Currently, this package supports the following embedders:
+
+- :ref:`openaiembeddings`
+- :ref:`sentencetransformerembeddings`
+- :ref:`vertexaiembeddings`
+- :ref:`mistralaiembeddings`
+- :ref:`cohereembeddings`
+- :ref:`azureopenaiembeddings`
+
+The `OpenAIEmbeddings` was illustrated previously. Here is how to use the `SentenceTransformerEmbeddings`:

 .. code:: python

examples/graphrag_with_mistral.py

Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
+"""End to end example of building a RAG pipeline backed by a Neo4j database.
+Requires MISTRAL_API_KEY to be in the env var.
+
+This example illustrates:
+- VectorCypherRetriever with a custom formatter function to extract relevant
+  context from neo4j result
+- Logging configuration
+"""
+
+import logging
+
+import neo4j
+from neo4j_graphrag.embeddings.mistral import MistralAIEmbeddings
+from neo4j_graphrag.generation import GraphRAG
+from neo4j_graphrag.indexes import create_vector_index
+from neo4j_graphrag.llm.mistralai_llm import MistralAILLM
+from neo4j_graphrag.retrievers import VectorCypherRetriever
+from neo4j_graphrag.types import RetrieverResultItem
+
+URI = "neo4j://localhost:7687"
+AUTH = ("neo4j", "password")
+DATABASE = "neo4j"
+INDEX_NAME = "moviePlotsEmbedding"
+
+
+# setup logger config
+logger = logging.getLogger("neo4j_graphrag")
+logging.basicConfig(format="%(asctime)s - %(message)s")
+logger.setLevel(logging.DEBUG)
+
+
+def formatter(record: neo4j.Record) -> RetrieverResultItem:
+    return RetrieverResultItem(content=f'{record.get("title")}: {record.get("plot")}')
+
+
+driver = neo4j.GraphDatabase.driver(
+    URI,
+    auth=AUTH,
+)
+
+create_vector_index(
+    driver,
+    INDEX_NAME,
+    label="Document",
+    embedding_property="vectorProperty",
+    dimensions=1024,
+    similarity_fn="cosine",
+)
+
+
+embedder = MistralAIEmbeddings()
+
+retriever = VectorCypherRetriever(
+    driver,
+    index_name=INDEX_NAME,
+    retrieval_query="with node, score return node.title as title, node.plot as plot",
+    result_formatter=formatter,
+    embedder=embedder,
+)
+
+llm = MistralAILLM(model_name="mistral-small-latest")
+
+rag = GraphRAG(retriever=retriever, llm=llm)
+
+result = rag.search("Tell me more about Avatar movies")
+print(result.answer)
+
+driver.close()
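
To see what the `formatter` function in this example hands on as context, it can be exercised without a database. The sketch below assumes a plain `dict` as a stand-in for `neo4j.Record` (both expose `.get`) and a local dataclass as a stand-in for `RetrieverResultItem`; the sample record is made up for illustration.

```python
from dataclasses import dataclass
from typing import Any, Mapping


@dataclass
class ResultItemSketch:
    """Local stand-in for neo4j_graphrag.types.RetrieverResultItem."""

    content: str


def formatter(record: Mapping[str, Any]) -> ResultItemSketch:
    # Same formatting as the example: "<title>: <plot>".
    return ResultItemSketch(content=f'{record.get("title")}: {record.get("plot")}')


# Made-up record shaped like the rows returned by the retrieval query.
sample = {"title": "Avatar", "plot": "A marine is sent to the moon Pandora."}
item = formatter(sample)
print(item.content)
```

Because `.get` returns `None` for a missing key, a row without `title` or `plot` is formatted with the literal text `None` rather than raising an error.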
Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
+from __future__ import annotations
+
+from random import random
+
+from neo4j import GraphDatabase
+from neo4j_graphrag.embeddings.mistral import MistralAIEmbeddings
+from neo4j_graphrag.indexes import create_vector_index
+from neo4j_graphrag.retrievers import VectorRetriever
+
+URI = "neo4j://localhost:7687"
+AUTH = ("neo4j", "password")
+
+INDEX_NAME = "embedding-name"
+DIMENSION = 1024
+
+# Connect to Neo4j database
+driver = GraphDatabase.driver(URI, auth=AUTH)
+
+embedder = MistralAIEmbeddings()
+
+# Creating the index
+create_vector_index(
+    driver,
+    INDEX_NAME,
+    label="Document",
+    embedding_property="vectorProperty",
+    dimensions=DIMENSION,
+    similarity_fn="euclidean",
+)
+
+# Initialize the retriever
+retriever = VectorRetriever(driver, INDEX_NAME, embedder)
+
+# Upsert the query
+vector = [random() for _ in range(DIMENSION)]
+insert_query = (
+    "MERGE (n:Document {id: $id})"
+    "WITH n "
+    "CALL db.create.setNodeVectorProperty(n, 'vectorProperty', $vector)"
+    "RETURN n"
+)
+parameters = {
+    "id": 0,
+    "vector": vector,
+}
+driver.execute_query(insert_query, parameters)
+
+# Perform the similarity search for a text query
+query_text = "Find me a book about Fremen"
+print(retriever.search(query_text=query_text, top_k=5))
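
The `insert_query` above is built from adjacent string literals, which Python joins with no separator. A quick check of the resulting text shows which fragments run together. Whether the Cypher parser accepts the joined form is not verified here, so ending each fragment with a space is the safer habit. `spaced_query` is an illustrative variant, not part of the commit.

```python
# Fragments exactly as written in the example: only "WITH n " ends with a space.
insert_query = (
    "MERGE (n:Document {id: $id})"
    "WITH n "
    "CALL db.create.setNodeVectorProperty(n, 'vectorProperty', $vector)"
    "RETURN n"
)

# Illustrative variant: every fragment but the last ends with a space.
spaced_query = (
    "MERGE (n:Document {id: $id}) "
    "WITH n "
    "CALL db.create.setNodeVectorProperty(n, 'vectorProperty', $vector) "
    "RETURN n"
)

print(insert_query)
print(spaced_query)
```

Printing both strings makes the difference visible: the first has `)WITH` and `)RETURN` fused, the second keeps every clause separated.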
