Python: #6499 Mistral AI Embedding Connector #7122
Merged: eavanvalkenburg merged 24 commits into microsoft:main from nmoeller:issue-6499-Mistral-Ai-Connector-embeddings on Aug 5, 2024.
Changes shown below are from 17 commits. Full commit history (24 commits):

1304cc1  Initial Commit for Mistral Embeddings with TestCases (nmoellerms)
23272b3  Merge branch 'main' into issue-6499-Mistral-Ai-Connector-embeddings (nmoeller)
8dbf320  small naming fixes more testcases (nmoellerms)
05693d6  Merge branch 'issue-6499-Mistral-Ai-Connector-embeddings' of https://… (nmoellerms)
4f48ca1  added some more test cases for completion (nmoellerms)
0980e92  Merge branch 'main' into issue-6499-Mistral-Ai-Connector-embeddings (nmoeller)
dbd97a9  Merge branch 'main' into issue-6499-Mistral-Ai-Connector-embeddings (nmoeller)
c6a0f9b  Merge remote-tracking branch 'origin/main' into issue-6499-Mistral-Ai… (nmoellerms)
c370896  Merge branch 'main' into issue-6499-Mistral-Ai-Connector-embeddings (nmoeller)
3dd1dd7  Merge remote-tracking branch 'origin/main' into issue-6499-Mistral-Ai… (nmoeller)
2679d16  Merge branch 'issue-6499-Mistral-Ai-Connector-embeddings' of https://… (nmoeller)
b2c6fb0  Merge remote-tracking branch 'origin/main' into issue-6499-Mistral-Ai… (nmoeller)
df8a64e  Integrated PR Feedback (nmoeller)
c1b35c8  removed unencessary cast (nmoeller)
2c039a4  Merge branch 'main' into issue-6499-Mistral-Ai-Connector-embeddings (nmoeller)
1492e19  Merge branch 'main' into issue-6499-Mistral-Ai-Connector-embeddings (nmoeller)
0cf610f  Merge branch 'main' into issue-6499-Mistral-Ai-Connector-embeddings (nmoeller)
fb3b8f2  addressed PR Feedback (nmoeller)
08c4cc9  Merge branch 'issue-6499-Mistral-Ai-Connector-embeddings' of https://… (nmoeller)
5323bea  Merge branch 'main' into issue-6499-Mistral-Ai-Connector-embeddings (nmoeller)
f92d4b6  Merge branch 'main' into issue-6499-Mistral-Ai-Connector-embeddings (nmoeller)
b09db5a  Update python/semantic_kernel/connectors/ai/mistral_ai/services/mistr… (nmoeller)
0350506  fixed response type of generate embeddings (nmoeller)
106cb93  Merge branch 'main' into issue-6499-Mistral-Ai-Connector-embeddings (nmoeller)
File: python/semantic_kernel/connectors/ai/mistral_ai/services/mistral_ai_text_embedding.py (new file, 112 additions, 0 deletions)
# Copyright (c) Microsoft. All rights reserved.

import sys

if sys.version_info >= (3, 12):
    from typing import Any, override  # pragma: no cover
else:
    from typing_extensions import Any, override  # pragma: no cover
import logging

from mistralai.async_client import MistralAsyncClient
from mistralai.models.embeddings import EmbeddingResponse
from numpy import array, ndarray
from pydantic import ValidationError

from semantic_kernel.connectors.ai.embeddings.embedding_generator_base import EmbeddingGeneratorBase
from semantic_kernel.connectors.ai.mistral_ai.settings.mistral_ai_settings import MistralAISettings
from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings
from semantic_kernel.exceptions.service_exceptions import ServiceInitializationError, ServiceResponseException
from semantic_kernel.utils.experimental_decorator import experimental_class

logger: logging.Logger = logging.getLogger(__name__)


@experimental_class
class MistralAITextEmbedding(EmbeddingGeneratorBase):
    """Mistral AI Inference Text Embedding Service."""

    client: MistralAsyncClient

    def __init__(
        self,
        ai_model_id: str | None = None,
        api_key: str | None = None,
        service_id: str | None = None,
        env_file_path: str | None = None,
        env_file_encoding: str | None = None,
        client: MistralAsyncClient | None = None,
    ) -> None:
        """Initialize the Mistral AI Text Embedding service.

        If no arguments are provided, the service will attempt to load the settings from the environment.
        The following environment variables are used:
        - MISTRALAI_API_KEY
        - MISTRALAI_EMBEDDING_MODEL_ID

        Args:
            ai_model_id (str | None): A string that is used to identify the model, such as the model name.
            api_key (str | None): The API key for the Mistral AI service deployment.
            service_id (str | None): Service ID for the embedding completion service.
            env_file_path (str | None): The path to the environment file.
            env_file_encoding (str | None): The encoding of the environment file.
            client (MistralAsyncClient | None): The Mistral AI client to use.

        Raises:
            ServiceInitializationError: If an error occurs during initialization.
        """
        try:
            mistralai_settings = MistralAISettings.create(
                api_key=api_key,
                embedding_model_id=ai_model_id,
                env_file_path=env_file_path,
                env_file_encoding=env_file_encoding,
            )
        except ValidationError as e:
            raise ServiceInitializationError(f"Failed to validate Mistral AI settings: {e}") from e

        if not mistralai_settings.embedding_model_id:
            raise ServiceInitializationError("The MistralAI embedding model ID is required.")

        if not client:
            client = MistralAsyncClient(
                api_key=mistralai_settings.api_key.get_secret_value()
            )

        super().__init__(
            service_id=service_id or mistralai_settings.embedding_model_id,
            ai_model_id=ai_model_id or mistralai_settings.embedding_model_id,
            client=client,
        )

    @override
    async def generate_embeddings(
        self,
        texts: list[str],
        settings: "PromptExecutionSettings | None" = None,
        **kwargs: Any,
    ) -> ndarray:
        embedding_response: EmbeddingResponse = await self.generate_raw_embeddings(texts, settings, **kwargs)
        return array([item.embedding for item in embedding_response.data])

    @override
    async def generate_raw_embeddings(
        self,
        texts: list[str],
        settings: "PromptExecutionSettings | None" = None,
        **kwargs: Any,
    ) -> "EmbeddingResponse":
        """Generate embeddings from the Mistral AI service."""
        try:
            response: EmbeddingResponse = await self.client.embeddings(
                model=self.ai_model_id,
                input=texts
            )
        except Exception as ex:
            raise ServiceResponseException(
                f"{type(self)} service failed to complete the embedding request.",
                ex,
            ) from ex

        return response
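For reviewers who want to try the new connector locally, here is a minimal usage sketch (not part of the diff). It assumes MISTRALAI_API_KEY and MISTRALAI_EMBEDDING_MODEL_ID are set as described in the docstring above; "mistral-embedding" is only an example service_id, and the input texts are placeholders.

```python
# Hypothetical usage sketch for the new MistralAITextEmbedding service.
# Assumes MISTRALAI_API_KEY and MISTRALAI_EMBEDDING_MODEL_ID are available
# in the environment or an .env file.
import asyncio

from semantic_kernel.connectors.ai.mistral_ai import MistralAITextEmbedding


async def main() -> None:
    # With no arguments, the service loads its settings from the environment.
    embedding_service = MistralAITextEmbedding(service_id="mistral-embedding")

    # generate_embeddings returns a numpy ndarray with one row per input text.
    embeddings = await embedding_service.generate_embeddings(["Sharks are fish.", "Whales are mammals."])
    print(embeddings.shape)

    # generate_raw_embeddings returns the underlying Mistral EmbeddingResponse,
    # whose .data items carry the individual embedding vectors.
    raw = await embedding_service.generate_raw_embeddings(["Penguins are birds."])
    print(len(raw.data))


if __name__ == "__main__":
    asyncio.run(main())
```

The import path used here matches the one exercised by the integration test below.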
File: python/tests/integration/embeddings/test_embedding_service.py (new file, 79 additions, 0 deletions)
# Copyright (c) Microsoft. All rights reserved.

import os

import pytest

import semantic_kernel as sk
from semantic_kernel.connectors.ai.embeddings.embedding_generator_base import EmbeddingGeneratorBase
from semantic_kernel.connectors.ai.mistral_ai import MistralAITextEmbedding
from semantic_kernel.core_plugins.text_memory_plugin import TextMemoryPlugin
from semantic_kernel.kernel import Kernel
from semantic_kernel.memory.semantic_text_memory import SemanticTextMemory

mistral_ai_setup: bool = False
try:
    if os.environ["MISTRALAI_API_KEY"] and os.environ["MISTRALAI_EMBEDDING_MODEL_ID"]:
        mistral_ai_setup = True
except KeyError:
    mistral_ai_setup = False


pytestmark = pytest.mark.parametrize(
    "embeddings_generator",
    [
        pytest.param(
            MistralAITextEmbedding() if mistral_ai_setup else None,
            marks=pytest.mark.skipif(not mistral_ai_setup, reason="Mistral AI environment variables not set"),
            id="MistralEmbeddings",
        )
    ],
)


@pytest.mark.asyncio(scope="module")
async def test_embedding_service(kernel: Kernel, embeddings_generator: EmbeddingGeneratorBase):
    kernel.add_service(embeddings_generator)

    memory = SemanticTextMemory(storage=sk.memory.VolatileMemoryStore(), embeddings_generator=embeddings_generator)
    kernel.add_plugin(TextMemoryPlugin(memory), "TextMemoryPlugin")

    await memory.save_reference(
        "test",
        external_id="info1",
        text="this is a test",
        external_source_name="external source",
    )

    # Add some documents to the semantic memory
    await memory.save_information("test", id="info1", text="Sharks are fish.")
    await memory.save_information("test", id="info2", text="Whales are mammals.")
    await memory.save_information("test", id="info3", text="Penguins are birds.")
    await memory.save_information("test", id="info4", text="Dolphins are mammals.")
    await memory.save_information("test", id="info5", text="Flies are insects.")

    # Search for documents
    query = "What are mammals?"
    result = await memory.search("test", query, limit=2, min_relevance_score=0.0)
    print(f"Query: {query}")
    print(f"\tAnswer 1: {result[0].text}")
    print(f"\tAnswer 2: {result[1].text}\n")
    assert "mammals." in result[0].text
    assert "mammals." in result[1].text

    query = "What are fish?"
    result = await memory.search("test", query, limit=1, min_relevance_score=0.0)
    print(f"Query: {query}")
    print(f"\tAnswer: {result[0].text}\n")
    assert result[0].text == "Sharks are fish."

    query = "What are insects?"
    result = await memory.search("test", query, limit=1, min_relevance_score=0.0)
    print(f"Query: {query}")
    print(f"\tAnswer: {result[0].text}\n")
    assert result[0].text == "Flies are insects."

    query = "What are birds?"
    result = await memory.search("test", query, limit=1, min_relevance_score=0.0)
    print(f"Query: {query}")
    print(f"\tAnswer: {result[0].text}\n")
    assert result[0].text == "Penguins are birds."
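The integration test above is skipped unless both Mistral AI environment variables are present. As a hedged convenience sketch (the exact invocation below is an assumption, not taken from the PR), a small wrapper can check for them before handing off to pytest:

```python
# Hypothetical helper for running the new integration test locally.
# Assumes pytest and the optional mistralai dependency are installed.
import os
import sys

import pytest

REQUIRED_VARS = ("MISTRALAI_API_KEY", "MISTRALAI_EMBEDDING_MODEL_ID")

if __name__ == "__main__":
    missing = [name for name in REQUIRED_VARS if not os.environ.get(name)]
    if missing:
        # Mirrors the skip condition inside the test module itself.
        print(f"Skipping Mistral AI embedding test, missing: {', '.join(missing)}")
        sys.exit(0)
    sys.exit(pytest.main(["-q", "python/tests/integration/embeddings/test_embedding_service.py"]))
```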