Skip to content

Python: #6499 Mistral AI Embedding Connector #7122

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
1304cc1
Initial Commit for Mistral Embeddings with TestCases
nmoellerms Jul 6, 2024
23272b3
Merge branch 'main' into issue-6499-Mistral-Ai-Connector-embeddings
nmoeller Jul 6, 2024
8dbf320
small naming fixes more testcases
nmoellerms Jul 8, 2024
05693d6
Merge branch 'issue-6499-Mistral-Ai-Connector-embeddings' of https://…
nmoellerms Jul 8, 2024
4f48ca1
added some more test cases for completion
nmoellerms Jul 8, 2024
0980e92
Merge branch 'main' into issue-6499-Mistral-Ai-Connector-embeddings
nmoeller Jul 8, 2024
dbd97a9
Merge branch 'main' into issue-6499-Mistral-Ai-Connector-embeddings
nmoeller Jul 9, 2024
c6a0f9b
Merge remote-tracking branch 'origin/main' into issue-6499-Mistral-Ai…
nmoellerms Jul 9, 2024
c370896
Merge branch 'main' into issue-6499-Mistral-Ai-Connector-embeddings
nmoeller Jul 10, 2024
3dd1dd7
Merge remote-tracking branch 'origin/main' into issue-6499-Mistral-Ai…
nmoeller Jul 22, 2024
2679d16
Merge branch 'issue-6499-Mistral-Ai-Connector-embeddings' of https://…
nmoeller Jul 22, 2024
b2c6fb0
Merge remote-tracking branch 'origin/main' into issue-6499-Mistral-Ai…
nmoeller Jul 22, 2024
df8a64e
Integrated PR Feedback
nmoeller Jul 22, 2024
c1b35c8
removed unnecessary cast
nmoeller Jul 22, 2024
2c039a4
Merge branch 'main' into issue-6499-Mistral-Ai-Connector-embeddings
nmoeller Jul 22, 2024
1492e19
Merge branch 'main' into issue-6499-Mistral-Ai-Connector-embeddings
nmoeller Jul 23, 2024
0cf610f
Merge branch 'main' into issue-6499-Mistral-Ai-Connector-embeddings
nmoeller Jul 24, 2024
fb3b8f2
addressed PR Feedback
nmoeller Jul 25, 2024
08c4cc9
Merge branch 'issue-6499-Mistral-Ai-Connector-embeddings' of https://…
nmoeller Jul 25, 2024
5323bea
Merge branch 'main' into issue-6499-Mistral-Ai-Connector-embeddings
nmoeller Jul 25, 2024
f92d4b6
Merge branch 'main' into issue-6499-Mistral-Ai-Connector-embeddings
nmoeller Jul 26, 2024
b09db5a
Update python/semantic_kernel/connectors/ai/mistral_ai/services/mistr…
nmoeller Aug 1, 2024
0350506
fixed response type of generate embeddings
nmoeller Aug 5, 2024
106cb93
Merge branch 'main' into issue-6499-Mistral-Ai-Connector-embeddings
nmoeller Aug 5, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/python-integration-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ jobs:
ACA_POOL_MANAGEMENT_ENDPOINT: ${{secrets.ACA_POOL_MANAGEMENT_ENDPOINT}}
MISTRALAI_API_KEY: ${{secrets.MISTRALAI_API_KEY}}
MISTRALAI_CHAT_MODEL_ID: ${{ vars.MISTRALAI_CHAT_MODEL_ID }}
MISTRALAI_EMBEDDING_MODEL_ID: ${{ vars.MISTRALAI_EMBEDDING_MODEL_ID }}
OLLAMA_MODEL: "${{ matrix.os == 'ubuntu-latest' && vars.OLLAMA_MODEL || '' }}" # phi3
run: |
if ${{ matrix.os == 'ubuntu-latest' }}; then
Expand Down Expand Up @@ -192,6 +193,7 @@ jobs:
ACA_POOL_MANAGEMENT_ENDPOINT: ${{secrets.ACA_POOL_MANAGEMENT_ENDPOINT}}
MISTRALAI_API_KEY: ${{secrets.MISTRALAI_API_KEY}}
MISTRALAI_CHAT_MODEL_ID: ${{ vars.MISTRALAI_CHAT_MODEL_ID }}
MISTRALAI_EMBEDDING_MODEL_ID: ${{ vars.MISTRALAI_EMBEDDING_MODEL_ID }}
OLLAMA_MODEL: "${{ matrix.os == 'ubuntu-latest' && vars.OLLAMA_MODEL || '' }}" # phi3
run: |
if ${{ matrix.os == 'ubuntu-latest' }}; then
Expand Down
2 changes: 2 additions & 0 deletions python/semantic_kernel/connectors/ai/mistral_ai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
MistralAIChatPromptExecutionSettings,
)
from semantic_kernel.connectors.ai.mistral_ai.services.mistral_ai_chat_completion import MistralAIChatCompletion
from semantic_kernel.connectors.ai.mistral_ai.services.mistral_ai_text_embedding import MistralAITextEmbedding

__all__ = [
"MistralAIChatCompletion",
"MistralAIChatPromptExecutionSettings",
"MistralAITextEmbedding",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# Copyright (c) Microsoft. All rights reserved.

import logging
import sys

# `Any` has been in `typing` since 3.5 — only `override` (new in 3.12)
# needs the typing_extensions fallback on older interpreters.
from typing import Any

if sys.version_info >= (3, 12):
    from typing import override  # pragma: no cover
else:
    from typing_extensions import override  # pragma: no cover

from mistralai.async_client import MistralAsyncClient
from mistralai.models.embeddings import EmbeddingResponse
from numpy import array, ndarray
from pydantic import ValidationError

from semantic_kernel.connectors.ai.embeddings.embedding_generator_base import EmbeddingGeneratorBase
from semantic_kernel.connectors.ai.mistral_ai.settings.mistral_ai_settings import MistralAISettings
from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings
from semantic_kernel.exceptions.service_exceptions import ServiceInitializationError, ServiceResponseException
from semantic_kernel.utils.experimental_decorator import experimental_class

logger: logging.Logger = logging.getLogger(__name__)


@experimental_class
class MistralAITextEmbedding(EmbeddingGeneratorBase):
    """Text embedding service backed by the Mistral AI inference API."""

    # Async client used for all embedding requests (pydantic model field).
    client: MistralAsyncClient

    def __init__(
        self,
        ai_model_id: str | None = None,
        api_key: str | None = None,
        service_id: str | None = None,
        env_file_path: str | None = None,
        env_file_encoding: str | None = None,
        client: MistralAsyncClient | None = None,
    ) -> None:
        """Initialize the Mistral AI Text Embedding service.

        Any argument left as None is resolved from the environment (or the
        optional .env file). The relevant environment variables are:
        - MISTRALAI_API_KEY
        - MISTRALAI_EMBEDDING_MODEL_ID

        Args:
            ai_model_id (str | None): Identifier of the embedding model to use.
            api_key (str | None): API key for the Mistral AI deployment.
            service_id (str | None): Service ID for this embedding service;
                defaults to the embedding model ID.
            env_file_path (str | None): Path of an optional .env settings file.
            env_file_encoding (str | None): Encoding of the .env settings file.
            client (MistralAsyncClient | None): Pre-built client to reuse;
                when omitted one is created from the resolved API key.

        Raises:
            ServiceInitializationError: If settings validation fails or no
                embedding model ID can be resolved.
        """
        try:
            settings = MistralAISettings.create(
                api_key=api_key,
                embedding_model_id=ai_model_id,
                env_file_path=env_file_path,
                env_file_encoding=env_file_encoding,
            )
        except ValidationError as e:
            raise ServiceInitializationError(f"Failed to validate Mistral AI settings: {e}") from e

        model_id = settings.embedding_model_id
        if not model_id:
            raise ServiceInitializationError("The MistralAI embedding model ID is required.")

        super().__init__(
            service_id=service_id or model_id,
            ai_model_id=ai_model_id or model_id,
            client=client or MistralAsyncClient(api_key=settings.api_key.get_secret_value()),
        )

    @override
    async def generate_embeddings(
        self,
        texts: list[str],
        settings: "PromptExecutionSettings | None" = None,
        **kwargs: Any,
    ) -> ndarray:
        """Return the embeddings for `texts` as a 2-D ndarray (one row per text)."""
        raw_response: EmbeddingResponse = await self.generate_raw_embeddings(texts, settings, **kwargs)
        vectors = [item.embedding for item in raw_response.data]
        return array(vectors)

    @override
    async def generate_raw_embeddings(
        self,
        texts: list[str],
        settings: "PromptExecutionSettings | None" = None,
        **kwargs: Any,
    ) -> "EmbeddingResponse":
        """Call the Mistral AI embeddings endpoint and return the raw response.

        `settings` and `kwargs` are accepted for interface compatibility but are
        not forwarded to the client.
        """
        try:
            return await self.client.embeddings(model=self.ai_model_id, input=texts)
        except Exception as ex:
            # Wrap any client/transport failure in the SK service exception.
            raise ServiceResponseException(
                f"{type(self)} service failed to complete the embedding request.",
                ex,
            ) from ex
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,13 @@ class MistralAISettings(KernelBaseSettings):
(Env var MISTRALAI_API_KEY)
- chat_model_id: str | None - The Mistral AI chat model ID to use; see https://docs.mistral.ai/getting-started/models/.
(Env var MISTRALAI_CHAT_MODEL_ID)
- embedding_model_id: str | None - The Mistral AI embedding model ID to use; see https://docs.mistral.ai/getting-started/models/.
(Env var MISTRALAI_EMBEDDING_MODEL_ID)
- env_file_path: str | None - if provided, the .env settings are read from this file path location
"""

env_prefix: ClassVar[str] = "MISTRALAI_"

api_key: SecretStr
chat_model_id: str | None = None
embedding_model_id: str | None = None
6 changes: 5 additions & 1 deletion python/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,11 @@ def mistralai_unit_test_env(monkeypatch, exclude_list, override_env_param_dict):
if override_env_param_dict is None:
override_env_param_dict = {}

env_vars = {"MISTRALAI_CHAT_MODEL_ID": "test_chat_model_id", "MISTRALAI_API_KEY": "test_api_key"}
env_vars = {
"MISTRALAI_CHAT_MODEL_ID": "test_chat_model_id",
"MISTRALAI_API_KEY": "test_api_key",
"MISTRALAI_EMBEDDING_MODEL_ID": "test_embedding_model_id",
}

env_vars.update(override_env_param_dict)

Expand Down
79 changes: 79 additions & 0 deletions python/tests/integration/embeddings/test_embedding_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright (c) Microsoft. All rights reserved.

import os

import pytest

import semantic_kernel as sk
from semantic_kernel.connectors.ai.embeddings.embedding_generator_base import EmbeddingGeneratorBase
from semantic_kernel.connectors.ai.mistral_ai import MistralAITextEmbedding
from semantic_kernel.core_plugins.text_memory_plugin import TextMemoryPlugin
from semantic_kernel.kernel import Kernel
from semantic_kernel.memory.semantic_text_memory import SemanticTextMemory

# True only when both Mistral AI variables are present AND non-empty.
# os.environ.get avoids the try/except-KeyError dance: a missing or empty
# value is falsy either way, which is exactly the condition we want.
mistral_ai_setup: bool = all(
    os.environ.get(name) for name in ("MISTRALAI_API_KEY", "MISTRALAI_EMBEDDING_MODEL_ID")
)


# Parametrize every test in this module with each configured embedding service;
# services whose environment is not configured are skipped, not failed.
pytestmark = pytest.mark.parametrize("embeddings_generator",
    [
        pytest.param(
            MistralAITextEmbedding() if mistral_ai_setup else None,
            marks=pytest.mark.skipif(not mistral_ai_setup, reason="Mistral AI environment variables not set"),
            id="MistralEmbeddings"
        )
    ]
)


@pytest.mark.asyncio(scope="module")
async def test_embedding_service(kernel: Kernel, embeddings_generator: EmbeddingGeneratorBase):
    """End-to-end check: store facts via semantic memory and retrieve them by similarity."""
    kernel.add_service(embeddings_generator)

    memory = SemanticTextMemory(storage=sk.memory.VolatileMemoryStore(), embeddings_generator=embeddings_generator)
    kernel.add_plugin(TextMemoryPlugin(memory), "TextMemoryPlugin")

    await memory.save_reference(
        "test",
        external_id="info1",
        text="this is a test",
        external_source_name="external source",
    )

    # Add some documents to the semantic memory
    facts = {
        "info1": "Sharks are fish.",
        "info2": "Whales are mammals.",
        "info3": "Penguins are birds.",
        "info4": "Dolphins are mammals.",
        "info5": "Flies are insects.",
    }
    for info_id, text in facts.items():
        await memory.save_information("test", id=info_id, text=text)

    # Two mammal facts exist, so a top-2 search should surface both.
    query = "What are mammals?"
    result = await memory.search("test", query, limit=2, min_relevance_score=0.0)
    print(f"Query: {query}")
    print(f"\tAnswer 1: {result[0].text}")
    print(f"\tAnswer 2: {result[1].text}\n")
    assert "mammals." in result[0].text
    assert "mammals." in result[1].text

    # Each remaining category has exactly one fact; top-1 must match exactly.
    for query, expected in [
        ("What are fish?", "Sharks are fish."),
        ("What are insects?", "Flies are insects."),
        ("What are birds?", "Penguins are birds."),
    ]:
        result = await memory.search("test", query, limit=1, min_relevance_score=0.0)
        print(f"Query: {query}")
        print(f"\tAnswer: {result[0].text}\n")
        assert result[0].text == expected
Original file line number Diff line number Diff line change
Expand Up @@ -133,20 +133,53 @@ def test_mistral_ai_chat_completion_init(mistralai_unit_test_env) -> None:
mistral_ai_chat_completion = MistralAIChatCompletion()

assert mistral_ai_chat_completion.ai_model_id == mistralai_unit_test_env["MISTRALAI_CHAT_MODEL_ID"]
assert mistral_ai_chat_completion.async_client._api_key == mistralai_unit_test_env["MISTRALAI_API_KEY"]
assert isinstance(mistral_ai_chat_completion, ChatCompletionClientBase)


@pytest.mark.parametrize("exclude_list", [["MISTRALAI_API_KEY"]], indirect=True)
def test_mistral_ai_chat_completion_init_with_empty_api_key(mistralai_unit_test_env) -> None:
ai_model_id = "test_model_id"
@pytest.mark.parametrize("exclude_list", [["MISTRALAI_API_KEY", "MISTRALAI_CHAT_MODEL_ID"]], indirect=True)
def test_mistral_ai_chat_completion_init_constructor(mistralai_unit_test_env) -> None:
    # With the env vars excluded, explicit constructor arguments must be used.
    service = MistralAIChatCompletion(
        api_key="overwrite_api_key",
        ai_model_id="overwrite_model_id",
        env_file_path="test.env",
    )

    assert service.ai_model_id == "overwrite_model_id"
    assert service.async_client._api_key == "overwrite_api_key"
    assert isinstance(service, ChatCompletionClientBase)


@pytest.mark.parametrize("exclude_list", [["MISTRALAI_API_KEY", "MISTRALAI_CHAT_MODEL_ID"]], indirect=True)
def test_mistral_ai_chat_completion_init_constructor_missing_model(mistralai_unit_test_env) -> None:
    # Model ID is available neither from the (excluded) environment nor the
    # constructor, so initialization must raise. NOTE(review): the scraped diff
    # interleaved old and new lines here (duplicate `env_file_path`, stale
    # `ai_model_id=ai_model_id`); reconstructed per the test's stated intent.
    with pytest.raises(ServiceInitializationError):
        MistralAIChatCompletion(
            api_key="overwrite_api_key",
            env_file_path="test.env",
        )


@pytest.mark.parametrize("exclude_list", [["MISTRALAI_API_KEY", "MISTRALAI_CHAT_MODEL_ID"]], indirect=True)
def test_mistral_ai_chat_completion_init_constructor_missing_api_key(mistralai_unit_test_env) -> None:
    # No API key from the (excluded) environment or the constructor: init fails.
    with pytest.raises(ServiceInitializationError):
        MistralAIChatCompletion(
            ai_model_id="overwrite_model_id",
            env_file_path="test.env",
        )


def test_mistral_ai_chat_completion_init_hybrid(mistralai_unit_test_env) -> None:
    # Constructor model ID overrides the env var; API key still comes from env.
    service = MistralAIChatCompletion(
        ai_model_id="overwrite_model_id",
        env_file_path="test.env",
    )
    assert service.ai_model_id == "overwrite_model_id"
    assert service.async_client._api_key == "test_api_key"


@pytest.mark.parametrize("exclude_list", [["MISTRALAI_CHAT_MODEL_ID"]], indirect=True)
def test_mistral_ai_chat_completion_init_with_empty_model_id(mistralai_unit_test_env) -> None:
with pytest.raises(ServiceInitializationError):
Expand Down
Loading
Loading