Skip to content

Python: #6499 Mistral AI Embedding Connector #7122

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
1304cc1
Initial Commit for Mistral Embeddings with TestCases
nmoellerms Jul 6, 2024
23272b3
Merge branch 'main' into issue-6499-Mistral-Ai-Connector-embeddings
nmoeller Jul 6, 2024
8dbf320
small naming fixes more testcases
nmoellerms Jul 8, 2024
05693d6
Merge branch 'issue-6499-Mistral-Ai-Connector-embeddings' of https://…
nmoellerms Jul 8, 2024
4f48ca1
added some more test cases for completion
nmoellerms Jul 8, 2024
0980e92
Merge branch 'main' into issue-6499-Mistral-Ai-Connector-embeddings
nmoeller Jul 8, 2024
dbd97a9
Merge branch 'main' into issue-6499-Mistral-Ai-Connector-embeddings
nmoeller Jul 9, 2024
c6a0f9b
Merge remote-tracking branch 'origin/main' into issue-6499-Mistral-Ai…
nmoellerms Jul 9, 2024
c370896
Merge branch 'main' into issue-6499-Mistral-Ai-Connector-embeddings
nmoeller Jul 10, 2024
3dd1dd7
Merge remote-tracking branch 'origin/main' into issue-6499-Mistral-Ai…
nmoeller Jul 22, 2024
2679d16
Merge branch 'issue-6499-Mistral-Ai-Connector-embeddings' of https://…
nmoeller Jul 22, 2024
b2c6fb0
Merge remote-tracking branch 'origin/main' into issue-6499-Mistral-Ai…
nmoeller Jul 22, 2024
df8a64e
Integrated PR Feedback
nmoeller Jul 22, 2024
c1b35c8
removed unnecessary cast
nmoeller Jul 22, 2024
2c039a4
Merge branch 'main' into issue-6499-Mistral-Ai-Connector-embeddings
nmoeller Jul 22, 2024
1492e19
Merge branch 'main' into issue-6499-Mistral-Ai-Connector-embeddings
nmoeller Jul 23, 2024
0cf610f
Merge branch 'main' into issue-6499-Mistral-Ai-Connector-embeddings
nmoeller Jul 24, 2024
fb3b8f2
addressed PR Feedback
nmoeller Jul 25, 2024
08c4cc9
Merge branch 'issue-6499-Mistral-Ai-Connector-embeddings' of https://…
nmoeller Jul 25, 2024
5323bea
Merge branch 'main' into issue-6499-Mistral-Ai-Connector-embeddings
nmoeller Jul 25, 2024
f92d4b6
Merge branch 'main' into issue-6499-Mistral-Ai-Connector-embeddings
nmoeller Jul 26, 2024
b09db5a
Update python/semantic_kernel/connectors/ai/mistral_ai/services/mistr…
nmoeller Aug 1, 2024
0350506
fixed response type of generate embeddings
nmoeller Aug 5, 2024
106cb93
Merge branch 'main' into issue-6499-Mistral-Ai-Connector-embeddings
nmoeller Aug 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/semantic_kernel/connectors/ai/mistral_ai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
MistralAIChatPromptExecutionSettings,
)
from semantic_kernel.connectors.ai.mistral_ai.services.mistral_ai_chat_completion import MistralAIChatCompletion
from semantic_kernel.connectors.ai.mistral_ai.services.mistral_ai_text_embedding import MistralAITextEmbedding

__all__ = [
"MistralAIChatCompletion",
"MistralAIChatPromptExecutionSettings",
"MistralAITextEmbedding",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Copyright (c) Microsoft. All rights reserved.


from typing import Any

from mistralai.async_client import MistralAsyncClient
from mistralai.models.embeddings import EmbeddingResponse
from numpy import array, ndarray
from pydantic import ValidationError

from semantic_kernel.connectors.ai.embeddings.embedding_generator_base import EmbeddingGeneratorBase
from semantic_kernel.connectors.ai.mistral_ai.settings.mistral_ai_settings import MistralAISettings
from semantic_kernel.exceptions.service_exceptions import ServiceInitializationError, ServiceResponseException
from semantic_kernel.utils.experimental_decorator import experimental_class


@experimental_class
class MistralAITextEmbedding(EmbeddingGeneratorBase):
    """Mistral AI Inference Text Embedding Service."""

    # Async client used for all requests to the Mistral AI embeddings endpoint.
    client: MistralAsyncClient

    def __init__(
        self,
        ai_model_id: str | None = None,
        api_key: str | None = None,
        service_id: str | None = None,
        env_file_path: str | None = None,
        env_file_encoding: str | None = None,
        client: MistralAsyncClient | None = None,
    ) -> None:
        """Initialize the Mistral AI Text Embedding service.

        If no arguments are provided, the service will attempt to load the settings from the environment.
        The following environment variables are used:
        - MISTRALAI_API_KEY
        - MISTRALAI_EMBEDDING_MODEL_ID

        Args:
            ai_model_id (str | None): A string that is used to identify the model such as the model name.
            api_key (str | None): The API key for the Mistral AI service deployment.
            service_id (str | None): Service ID for the embedding completion service.
            env_file_path (str | None): The path to the environment file.
            env_file_encoding (str | None): The encoding of the environment file.
            client (MistralAsyncClient | None): The Mistral AI client to use.
                If provided, it is used as-is and no new client is constructed.

        Raises:
            ServiceInitializationError: If an error occurs during initialization.
        """
        try:
            mistralai_settings = MistralAISettings.create(
                api_key=api_key,
                embedding_model_id=ai_model_id,
                env_file_path=env_file_path,
                env_file_encoding=env_file_encoding,
            )
        except ValidationError as e:
            raise ServiceInitializationError(f"Failed to validate Mistral AI settings: {e}") from e

        # Validate the embedding model id (not the chat model id): this service
        # only ever issues embedding requests, so MISTRALAI_EMBEDDING_MODEL_ID
        # is the required setting here.
        if not mistralai_settings.embedding_model_id:
            raise ServiceInitializationError("The MistralAI embedding model ID is required.")

        if not client:
            client = MistralAsyncClient(
                api_key=mistralai_settings.api_key.get_secret_value()
            )

        super().__init__(
            service_id=service_id or mistralai_settings.embedding_model_id,
            ai_model_id=ai_model_id or mistralai_settings.embedding_model_id,
            client=client,
        )

    async def generate_embeddings(self, texts: list[str], **kwargs: Any) -> ndarray:
        """Generate embeddings from the Mistral AI service.

        Args:
            texts (list[str]): The texts to generate embeddings for.
            kwargs (Any): Additional arguments (currently not forwarded to the client call).

        Returns:
            ndarray: A 2D array with one embedding row per input text.

        Raises:
            ServiceResponseException: If the underlying client request fails.
        """
        try:
            response: EmbeddingResponse = await self.client.embeddings(
                model=self.ai_model_id,
                input=texts,
            )
        except Exception as ex:
            raise ServiceResponseException(
                f"{type(self)} service failed to complete the embedding request.",
                ex,
            ) from ex

        return array([array(item.embedding) for item in response.data])
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,13 @@ class MistralAISettings(KernelBaseSettings):
(Env var MISTRALAI_API_KEY)
- chat_model_id: str | None - The Mistral AI chat model ID to use see https://docs.mistral.ai/getting-started/models/.
(Env var MISTRALAI_CHAT_MODEL_ID)
- embedding_model_id: str | None - The Mistral AI embedding model ID to use see https://docs.mistral.ai/getting-started/models/.
(Env var MISTRALAI_EMBEDDING_MODEL_ID)
- env_file_path: str | None - if provided, the .env settings are read from this file path location
"""

env_prefix: ClassVar[str] = "MISTRALAI_"

api_key: SecretStr
chat_model_id: str | None = None
embedding_model_id: str | None = None
3 changes: 2 additions & 1 deletion python/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,8 @@ def mistralai_unit_test_env(monkeypatch, exclude_list, override_env_param_dict):

env_vars = {
"MISTRALAI_CHAT_MODEL_ID": "test_chat_model_id",
"MISTRALAI_API_KEY": "test_api_key"
"MISTRALAI_API_KEY": "test_api_key",
"MISTRALAI_EMBEDDING_MODEL_ID": "test_embedding_model_id",
}

env_vars.update(override_env_param_dict)
Expand Down
79 changes: 79 additions & 0 deletions python/tests/integration/embeddings/test_embedding_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright (c) Microsoft. All rights reserved.

import os

import pytest

import semantic_kernel as sk
from semantic_kernel.connectors.ai.embeddings.embedding_generator_base import EmbeddingGeneratorBase
from semantic_kernel.connectors.ai.mistral_ai import MistralAITextEmbedding
from semantic_kernel.core_plugins.text_memory_plugin import TextMemoryPlugin
from semantic_kernel.kernel import Kernel
from semantic_kernel.memory.semantic_text_memory import SemanticTextMemory

# Run the Mistral AI integration test only when the required credentials and
# the *embedding* model id are configured: MistralAITextEmbedding() reads
# MISTRALAI_EMBEDDING_MODEL_ID, so gating on the chat model id would let the
# parametrization below crash at collection time.
mistral_ai_setup: bool = bool(
    os.environ.get("MISTRALAI_API_KEY") and os.environ.get("MISTRALAI_EMBEDDING_MODEL_ID")
)


pytestmark = pytest.mark.parametrize(
    "embeddings_generator",
    [
        pytest.param(
            MistralAITextEmbedding() if mistral_ai_setup else None,
            marks=pytest.mark.skipif(not mistral_ai_setup, reason="Mistral AI environment variables not set"),
            id="MistralEmbeddings",
        )
    ],
)


@pytest.mark.asyncio(scope="module")
async def test_embedding_service(kernel: Kernel, embeddings_generator: EmbeddingGeneratorBase):
    """End-to-end check: store facts in semantic memory and query them back."""
    kernel.add_service(embeddings_generator)

    memory = SemanticTextMemory(storage=sk.memory.VolatileMemoryStore(), embeddings_generator=embeddings_generator)
    kernel.add_plugin(TextMemoryPlugin(memory), "TextMemoryPlugin")

    await memory.save_reference(
        "test",
        external_id="info1",
        text="this is a test",
        external_source_name="external source",
    )

    # Seed the semantic memory with a handful of facts.
    facts = {
        "info1": "Sharks are fish.",
        "info2": "Whales are mammals.",
        "info3": "Penguins are birds.",
        "info4": "Dolphins are mammals.",
        "info5": "Flies are insects.",
    }
    for info_id, fact in facts.items():
        await memory.save_information("test", id=info_id, text=fact)

    # Two-result query: both hits should concern mammals.
    query = "What are mammals?"
    results = await memory.search("test", query, limit=2, min_relevance_score=0.0)
    print(f"Query: {query}")
    print(f"\tAnswer 1: {results[0].text}")
    print(f"\tAnswer 2: {results[1].text}\n")
    assert "mammals." in results[0].text
    assert "mammals." in results[1].text

    # Single-result queries: the top hit must be the exact stored fact.
    for query, expected in [
        ("What are fish?", "Sharks are fish."),
        ("What are insects?", "Flies are insects."),
        ("What are birds?", "Penguins are birds."),
    ]:
        results = await memory.search("test", query, limit=1, min_relevance_score=0.0)
        print(f"Query: {query}")
        print(f"\tAnswer: {results[0].text}\n")
        assert results[0].text == expected
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Copyright (c) Microsoft. All rights reserved.

from unittest.mock import AsyncMock, MagicMock

import pytest
from mistralai.async_client import MistralAsyncClient

from semantic_kernel.connectors.ai.mistral_ai.services.mistral_ai_text_embedding import MistralAITextEmbedding
from semantic_kernel.exceptions.service_exceptions import ServiceInitializationError, ServiceResponseException


def test_embedding_with_env_variables(mistralai_unit_test_env):
    """The service picks up the api key and embedding model id from the environment."""
    service = MistralAITextEmbedding()
    assert service.client._api_key == "test_api_key"
    assert service.ai_model_id == "test_embedding_model_id"


def test_embedding_with_client(mistralai_unit_test_env):
    """A caller-supplied client is used as-is instead of constructing a new one."""
    mock_client = MagicMock(spec=MistralAsyncClient)
    service = MistralAITextEmbedding(client=mock_client)
    assert service.ai_model_id == "test_embedding_model_id"
    assert service.client == mock_client


def test_embedding_with_api_key(mistralai_unit_test_env):
    """An explicitly passed api key overrides the environment value."""
    service = MistralAITextEmbedding(api_key="overwrite-api-key")
    assert service.client._api_key == "overwrite-api-key"


def test_embedding_with_model(mistralai_unit_test_env):
    """An explicitly passed model id overrides the environment value."""
    service = MistralAITextEmbedding(ai_model_id="overwrite-model")
    assert service.ai_model_id == "overwrite-model"


@pytest.mark.parametrize("exclude_list", [["MISTRALAI_EMBEDDING_MODEL_ID"]], indirect=True)
def test_embedding_missing_model(mistralai_unit_test_env):
    """Initialization fails when the embedding model id is not configured.

    Excludes MISTRALAI_EMBEDDING_MODEL_ID (not the chat model id), since that
    is the variable the embedding service actually requires.
    """
    with pytest.raises(ServiceInitializationError):
        MistralAITextEmbedding(
            env_file_path="test.env",
        )


@pytest.mark.parametrize("exclude_list", [["MISTRALAI_API_KEY"]], indirect=True)
def test_embedding_missing_api_key(mistralai_unit_test_env):
    """Initialization fails when no api key is available from args or environment."""
    with pytest.raises(ServiceInitializationError):
        MistralAITextEmbedding(env_file_path="test.env")


@pytest.mark.asyncio
async def test_embedding_generate_embedding(mistralai_unit_test_env):
    """generate_embeddings returns a 2D array built from the client's response data."""
    client = AsyncMock(spec=MistralAsyncClient)
    client.embeddings.return_value = MagicMock(data=[MagicMock(embedding=[1, 2, 3, 4, 5])])
    service = MistralAITextEmbedding(client=client)
    result = await service.generate_embeddings(["test"])
    assert result.tolist() == [[1, 2, 3, 4, 5]]


@pytest.mark.asyncio
async def test_embedding_generate_embedding_exception(mistralai_unit_test_env):
    """Failures raised by the client are surfaced as ServiceResponseException."""
    client = AsyncMock(spec=MistralAsyncClient)
    client.embeddings.side_effect = Exception("Test Exception")
    service = MistralAITextEmbedding(client=client)
    with pytest.raises(ServiceResponseException):
        await service.generate_embeddings(["test"])
Loading