diff --git a/.github/workflows/python-integration-tests.yml b/.github/workflows/python-integration-tests.yml index f06e527d539b..85f3890f7be5 100644 --- a/.github/workflows/python-integration-tests.yml +++ b/.github/workflows/python-integration-tests.yml @@ -123,6 +123,7 @@ jobs: ACA_POOL_MANAGEMENT_ENDPOINT: ${{secrets.ACA_POOL_MANAGEMENT_ENDPOINT}} MISTRALAI_API_KEY: ${{secrets.MISTRALAI_API_KEY}} MISTRALAI_CHAT_MODEL_ID: ${{ vars.MISTRALAI_CHAT_MODEL_ID }} + MISTRALAI_EMBEDDING_MODEL_ID: ${{ vars.MISTRALAI_EMBEDDING_MODEL_ID }} OLLAMA_MODEL: "${{ matrix.os == 'ubuntu-latest' && vars.OLLAMA_MODEL || '' }}" # phi3 GOOGLE_AI_GEMINI_MODEL_ID: ${{ vars.GOOGLE_AI_GEMINI_MODEL_ID }} GOOGLE_AI_EMBEDDING_MODEL_ID: ${{ vars.GOOGLE_AI_EMBEDDING_MODEL_ID }} @@ -233,6 +234,7 @@ jobs: ACA_POOL_MANAGEMENT_ENDPOINT: ${{secrets.ACA_POOL_MANAGEMENT_ENDPOINT}} MISTRALAI_API_KEY: ${{secrets.MISTRALAI_API_KEY}} MISTRALAI_CHAT_MODEL_ID: ${{ vars.MISTRALAI_CHAT_MODEL_ID }} + MISTRALAI_EMBEDDING_MODEL_ID: ${{ vars.MISTRALAI_EMBEDDING_MODEL_ID }} OLLAMA_MODEL: "${{ matrix.os == 'ubuntu-latest' && vars.OLLAMA_MODEL || '' }}" # phi3 GOOGLE_AI_GEMINI_MODEL_ID: ${{ vars.GOOGLE_AI_GEMINI_MODEL_ID }} GOOGLE_AI_EMBEDDING_MODEL_ID: ${{ vars.GOOGLE_AI_EMBEDDING_MODEL_ID }} diff --git a/python/semantic_kernel/connectors/ai/mistral_ai/__init__.py b/python/semantic_kernel/connectors/ai/mistral_ai/__init__.py index 9b2d7d379066..8dc0c473a53f 100644 --- a/python/semantic_kernel/connectors/ai/mistral_ai/__init__.py +++ b/python/semantic_kernel/connectors/ai/mistral_ai/__init__.py @@ -4,8 +4,10 @@ MistralAIChatPromptExecutionSettings, ) from semantic_kernel.connectors.ai.mistral_ai.services.mistral_ai_chat_completion import MistralAIChatCompletion +from semantic_kernel.connectors.ai.mistral_ai.services.mistral_ai_text_embedding import MistralAITextEmbedding __all__ = [ "MistralAIChatCompletion", "MistralAIChatPromptExecutionSettings", + "MistralAITextEmbedding", ] diff --git 
a/python/semantic_kernel/connectors/ai/mistral_ai/services/mistral_ai_text_embedding.py b/python/semantic_kernel/connectors/ai/mistral_ai/services/mistral_ai_text_embedding.py new file mode 100644 index 000000000000..24b2905b1587 --- /dev/null +++ b/python/semantic_kernel/connectors/ai/mistral_ai/services/mistral_ai_text_embedding.py @@ -0,0 +1,112 @@ +# Copyright (c) Microsoft. All rights reserved. + +import sys + +if sys.version_info >= (3, 12): + from typing import Any, override # pragma: no cover +else: + from typing_extensions import Any, override # pragma: no cover +import logging + +from mistralai.async_client import MistralAsyncClient +from mistralai.models.embeddings import EmbeddingResponse +from numpy import array, ndarray +from pydantic import ValidationError + +from semantic_kernel.connectors.ai.embeddings.embedding_generator_base import EmbeddingGeneratorBase +from semantic_kernel.connectors.ai.mistral_ai.settings.mistral_ai_settings import MistralAISettings +from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings +from semantic_kernel.exceptions.service_exceptions import ServiceInitializationError, ServiceResponseException +from semantic_kernel.utils.experimental_decorator import experimental_class + +logger: logging.Logger = logging.getLogger(__name__) + + +@experimental_class +class MistralAITextEmbedding(EmbeddingGeneratorBase): + """Mistral AI Inference Text Embedding Service.""" + + client: MistralAsyncClient + + def __init__( + self, + ai_model_id: str | None = None, + api_key: str | None = None, + service_id: str | None = None, + env_file_path: str | None = None, + env_file_encoding: str | None = None, + client: MistralAsyncClient | None = None, + ) -> None: + """Initialize the Mistral AI Text Embedding service. + + If no arguments are provided, the service will attempt to load the settings from the environment. 
+ The following environment variables are used: + - MISTRALAI_API_KEY + - MISTRALAI_EMBEDDING_MODEL_ID + + Args: + ai_model_id (str | None): A string that is used to identify the model such as the model name. + api_key (str | None): The API key for the Mistral AI service deployment. + service_id (str | None): Service ID for the embedding service. + env_file_path (str | None): The path to the environment file. + env_file_encoding (str | None): The encoding of the environment file. + client (MistralAsyncClient | None): The Mistral AI client to use. + + Raises: + ServiceInitializationError: If an error occurs during initialization. + """ + try: + mistralai_settings = MistralAISettings.create( + api_key=api_key, + embedding_model_id=ai_model_id, + env_file_path=env_file_path, + env_file_encoding=env_file_encoding, + ) + except ValidationError as e: + raise ServiceInitializationError(f"Failed to validate Mistral AI settings: {e}") from e + + if not mistralai_settings.embedding_model_id: + raise ServiceInitializationError("The MistralAI embedding model ID is required.") + + if not client: + client = MistralAsyncClient( + api_key=mistralai_settings.api_key.get_secret_value() + ) + + super().__init__( + service_id=service_id or mistralai_settings.embedding_model_id, + ai_model_id=ai_model_id or mistralai_settings.embedding_model_id, + client=client, + ) + + @override + async def generate_embeddings( + self, + texts: list[str], + settings: "PromptExecutionSettings | None" = None, + **kwargs: Any, + ) -> ndarray: + embedding_response = await self.generate_raw_embeddings(texts, settings, **kwargs) + return array(embedding_response) + + @override + async def generate_raw_embeddings( + self, + texts: list[str], + settings: "PromptExecutionSettings | None" = None, + **kwargs: Any, + ) -> Any: + """Generate embeddings from the Mistral AI service.""" + try: + + embedding_response: EmbeddingResponse = await self.client.embeddings( + model=self.ai_model_id, + input=texts 
+ ) + except Exception as ex: + raise ServiceResponseException( + f"{type(self)} service failed to complete the embedding request.", + ex, + ) from ex + + return [item.embedding for item in embedding_response.data] diff --git a/python/semantic_kernel/connectors/ai/mistral_ai/settings/mistral_ai_settings.py b/python/semantic_kernel/connectors/ai/mistral_ai/settings/mistral_ai_settings.py index 8139be0ba568..8acd90148d69 100644 --- a/python/semantic_kernel/connectors/ai/mistral_ai/settings/mistral_ai_settings.py +++ b/python/semantic_kernel/connectors/ai/mistral_ai/settings/mistral_ai_settings.py @@ -20,6 +20,8 @@ class MistralAISettings(KernelBaseSettings): (Env var MISTRALAI_API_KEY) - chat_model_id: str | None - The The Mistral AI chat model ID to use see https://docs.mistral.ai/getting-started/models/. (Env var MISTRALAI_CHAT_MODEL_ID) + - embedding_model_id: str | None - The Mistral AI embedding model ID to use; see https://docs.mistral.ai/getting-started/models/. + (Env var MISTRALAI_EMBEDDING_MODEL_ID) - env_file_path: str | None - if provided, the .env settings are read from this file path location """ @@ -27,3 +29,4 @@ class MistralAISettings(KernelBaseSettings): api_key: SecretStr chat_model_id: str | None = None + embedding_model_id: str | None = None diff --git a/python/tests/conftest.py b/python/tests/conftest.py index bff9b5eae24c..41575b36e337 100644 --- a/python/tests/conftest.py +++ b/python/tests/conftest.py @@ -279,7 +279,11 @@ def mistralai_unit_test_env(monkeypatch, exclude_list, override_env_param_dict): if override_env_param_dict is None: override_env_param_dict = {} - env_vars = {"MISTRALAI_CHAT_MODEL_ID": "test_chat_model_id", "MISTRALAI_API_KEY": "test_api_key"} + env_vars = { + "MISTRALAI_CHAT_MODEL_ID": "test_chat_model_id", + "MISTRALAI_API_KEY": "test_api_key", + "MISTRALAI_EMBEDDING_MODEL_ID": "test_embedding_model_id", + } env_vars.update(override_env_param_dict) diff --git 
a/python/tests/integration/embeddings/test_embedding_service.py b/python/tests/integration/embeddings/test_embedding_service.py new file mode 100644 index 000000000000..60917da23dca --- /dev/null +++ b/python/tests/integration/embeddings/test_embedding_service.py @@ -0,0 +1,79 @@ +# Copyright (c) Microsoft. All rights reserved. + +import os + +import pytest + +import semantic_kernel as sk +from semantic_kernel.connectors.ai.embeddings.embedding_generator_base import EmbeddingGeneratorBase +from semantic_kernel.connectors.ai.mistral_ai import MistralAITextEmbedding +from semantic_kernel.core_plugins.text_memory_plugin import TextMemoryPlugin +from semantic_kernel.kernel import Kernel +from semantic_kernel.memory.semantic_text_memory import SemanticTextMemory + +mistral_ai_setup: bool = False +try: + if os.environ["MISTRALAI_API_KEY"] and os.environ["MISTRALAI_EMBEDDING_MODEL_ID"]: + mistral_ai_setup = True +except KeyError: + mistral_ai_setup = False + + +pytestmark = pytest.mark.parametrize("embeddings_generator", + [ + pytest.param( + MistralAITextEmbedding() if mistral_ai_setup else None, + marks=pytest.mark.skipif(not mistral_ai_setup, reason="Mistral AI environment variables not set"), + id="MistralEmbeddings" + ) + ] +) + + +@pytest.mark.asyncio(scope="module") +async def test_embedding_service(kernel: Kernel, embeddings_generator: EmbeddingGeneratorBase): + kernel.add_service(embeddings_generator) + + memory = SemanticTextMemory(storage=sk.memory.VolatileMemoryStore(), embeddings_generator=embeddings_generator) + kernel.add_plugin(TextMemoryPlugin(memory), "TextMemoryPlugin") + + await memory.save_reference( + "test", + external_id="info1", + text="this is a test", + external_source_name="external source", + ) + + # Add some documents to the semantic memory + await memory.save_information("test", id="info1", text="Sharks are fish.") + await memory.save_information("test", id="info2", text="Whales are mammals.") + await memory.save_information("test", 
id="info3", text="Penguins are birds.") + await memory.save_information("test", id="info4", text="Dolphins are mammals.") + await memory.save_information("test", id="info5", text="Flies are insects.") + + # Search for documents + query = "What are mammals?" + result = await memory.search("test", query, limit=2, min_relevance_score=0.0) + print(f"Query: {query}") + print(f"\tAnswer 1: {result[0].text}") + print(f"\tAnswer 2: {result[1].text}\n") + assert "mammals." in result[0].text + assert "mammals." in result[1].text + + query = "What are fish?" + result = await memory.search("test", query, limit=1, min_relevance_score=0.0) + print(f"Query: {query}") + print(f"\tAnswer: {result[0].text}\n") + assert result[0].text == "Sharks are fish." + + query = "What are insects?" + result = await memory.search("test", query, limit=1, min_relevance_score=0.0) + print(f"Query: {query}") + print(f"\tAnswer: {result[0].text}\n") + assert result[0].text == "Flies are insects." + + query = "What are birds?" + result = await memory.search("test", query, limit=1, min_relevance_score=0.0) + print(f"Query: {query}") + print(f"\tAnswer: {result[0].text}\n") + assert result[0].text == "Penguins are birds." 
diff --git a/python/tests/unit/connectors/mistral_ai/services/test_mistralai_chat_completion.py b/python/tests/unit/connectors/mistral_ai/services/test_mistralai_chat_completion.py index 1fe0a868a9ff..5a4e18521c5d 100644 --- a/python/tests/unit/connectors/mistral_ai/services/test_mistralai_chat_completion.py +++ b/python/tests/unit/connectors/mistral_ai/services/test_mistralai_chat_completion.py @@ -130,20 +130,53 @@ def test_mistral_ai_chat_completion_init(mistralai_unit_test_env) -> None: mistral_ai_chat_completion = MistralAIChatCompletion() assert mistral_ai_chat_completion.ai_model_id == mistralai_unit_test_env["MISTRALAI_CHAT_MODEL_ID"] + assert mistral_ai_chat_completion.async_client._api_key == mistralai_unit_test_env["MISTRALAI_API_KEY"] assert isinstance(mistral_ai_chat_completion, ChatCompletionClientBase) -@pytest.mark.parametrize("exclude_list", [["MISTRALAI_API_KEY"]], indirect=True) -def test_mistral_ai_chat_completion_init_with_empty_api_key(mistralai_unit_test_env) -> None: - ai_model_id = "test_model_id" +@pytest.mark.parametrize("exclude_list", [["MISTRALAI_API_KEY", "MISTRALAI_CHAT_MODEL_ID"]], indirect=True) +def test_mistral_ai_chat_completion_init_constructor(mistralai_unit_test_env) -> None: + # Test successful initialization + mistral_ai_chat_completion = MistralAIChatCompletion( + api_key="overwrite_api_key", + ai_model_id="overwrite_model_id", + env_file_path="test.env", + ) + + assert mistral_ai_chat_completion.ai_model_id == "overwrite_model_id" + assert mistral_ai_chat_completion.async_client._api_key == "overwrite_api_key" + assert isinstance(mistral_ai_chat_completion, ChatCompletionClientBase) + +@pytest.mark.parametrize("exclude_list", [["MISTRALAI_API_KEY", "MISTRALAI_CHAT_MODEL_ID"]], indirect=True) +def test_mistral_ai_chat_completion_init_constructor_missing_model(mistralai_unit_test_env) -> None: + # Test that initialization fails when the model ID is missing with pytest.raises(ServiceInitializationError): MistralAIChatCompletion( - 
ai_model_id=ai_model_id, - env_file_path="test.env", + api_key="overwrite_api_key", + env_file_path="test.env" ) +@pytest.mark.parametrize("exclude_list", [["MISTRALAI_API_KEY", "MISTRALAI_CHAT_MODEL_ID"]], indirect=True) +def test_mistral_ai_chat_completion_init_constructor_missing_api_key(mistralai_unit_test_env) -> None: + # Test that initialization fails when the API key is missing + with pytest.raises(ServiceInitializationError): + MistralAIChatCompletion( + ai_model_id="overwrite_model_id", + env_file_path="test.env" + ) + + +def test_mistral_ai_chat_completion_init_hybrid(mistralai_unit_test_env) -> None: + mistral_ai_chat_completion = MistralAIChatCompletion( + ai_model_id="overwrite_model_id", + env_file_path="test.env", + ) + assert mistral_ai_chat_completion.ai_model_id == "overwrite_model_id" + assert mistral_ai_chat_completion.async_client._api_key == "test_api_key" + + @pytest.mark.parametrize("exclude_list", [["MISTRALAI_CHAT_MODEL_ID"]], indirect=True) def test_mistral_ai_chat_completion_init_with_empty_model_id(mistralai_unit_test_env) -> None: with pytest.raises(ServiceInitializationError): diff --git a/python/tests/unit/connectors/mistral_ai/services/test_mistralai_text_embeddings.py b/python/tests/unit/connectors/mistral_ai/services/test_mistralai_text_embeddings.py new file mode 100644 index 000000000000..98550ca6f1ad --- /dev/null +++ b/python/tests/unit/connectors/mistral_ai/services/test_mistralai_text_embeddings.py @@ -0,0 +1,114 @@ +# Copyright (c) Microsoft. All rights reserved. 
+ +from unittest.mock import AsyncMock, MagicMock + +import pytest +from mistralai.async_client import MistralAsyncClient +from mistralai.models.embeddings import EmbeddingResponse + +from semantic_kernel.connectors.ai.mistral_ai.services.mistral_ai_text_embedding import MistralAITextEmbedding +from semantic_kernel.exceptions.service_exceptions import ServiceInitializationError, ServiceResponseException + + +def test_embedding_with_env_variables(mistralai_unit_test_env): + text_embedding = MistralAITextEmbedding() + assert text_embedding.ai_model_id == "test_embedding_model_id" + assert text_embedding.client._api_key == "test_api_key" + + +@pytest.mark.parametrize("exclude_list", [["MISTRALAI_API_KEY", "MISTRALAI_EMBEDDING_MODEL_ID"]], indirect=True) +def test_embedding_with_constructor(mistralai_unit_test_env): + text_embedding = MistralAITextEmbedding( + api_key="overwrite-api-key", + ai_model_id="overwrite-model", + ) + assert text_embedding.ai_model_id == "overwrite-model" + assert text_embedding.client._api_key == "overwrite-api-key" + + +def test_embedding_with_client(mistralai_unit_test_env): + client = MagicMock(spec=MistralAsyncClient) + text_embedding = MistralAITextEmbedding(client=client) + assert text_embedding.client == client + assert text_embedding.ai_model_id == "test_embedding_model_id" + + +def test_embedding_with_api_key(mistralai_unit_test_env): + text_embedding = MistralAITextEmbedding(api_key="overwrite-api-key") + assert text_embedding.client._api_key == "overwrite-api-key" + assert text_embedding.ai_model_id == "test_embedding_model_id" + + +def test_embedding_with_model(mistralai_unit_test_env): + text_embedding = MistralAITextEmbedding(ai_model_id="overwrite-model") + assert text_embedding.ai_model_id == "overwrite-model" + assert text_embedding.client._api_key == "test_api_key" + + +@pytest.mark.parametrize("exclude_list", [["MISTRALAI_EMBEDDING_MODEL_ID"]], indirect=True) +def 
test_embedding_with_model_without_env(mistralai_unit_test_env): + text_embedding = MistralAITextEmbedding(ai_model_id="overwrite-model") + assert text_embedding.ai_model_id == "overwrite-model" + assert text_embedding.client._api_key == "test_api_key" + + +@pytest.mark.parametrize("exclude_list", [["MISTRALAI_EMBEDDING_MODEL_ID"]], indirect=True) +def test_embedding_missing_model(mistralai_unit_test_env): + with pytest.raises(ServiceInitializationError): + MistralAITextEmbedding( + env_file_path="test.env", + ) + + +@pytest.mark.parametrize("exclude_list", [["MISTRALAI_API_KEY"]], indirect=True) +def test_embedding_missing_api_key(mistralai_unit_test_env): + with pytest.raises(ServiceInitializationError): + MistralAITextEmbedding( + env_file_path="test.env", + ) + + +@pytest.mark.parametrize("exclude_list", [["MISTRALAI_API_KEY", "MISTRALAI_EMBEDDING_MODEL_ID"]], indirect=True) +def test_embedding_missing_api_key_constructor(mistralai_unit_test_env): + with pytest.raises(ServiceInitializationError): + MistralAITextEmbedding( + env_file_path="test.env", + ) + + +@pytest.mark.parametrize("exclude_list", [["MISTRALAI_API_KEY", "MISTRALAI_EMBEDDING_MODEL_ID"]], indirect=True) +def test_embedding_missing_model_constructor(mistralai_unit_test_env): + with pytest.raises(ServiceInitializationError): + MistralAITextEmbedding( + api_key="test_api_key", + env_file_path="test.env", + ) + + +@pytest.mark.asyncio +async def test_embedding_generate_raw_embedding(mistralai_unit_test_env): + mock_client = AsyncMock(spec=MistralAsyncClient) + mock_embedding_response = MagicMock(spec=EmbeddingResponse, data=[MagicMock(embedding=[1, 2, 3, 4, 5])]) + mock_client.embeddings.return_value = mock_embedding_response + text_embedding = MistralAITextEmbedding(client=mock_client) + embedding = await text_embedding.generate_raw_embeddings(["test"]) + assert embedding == [[1, 2, 3, 4, 5]] + + +@pytest.mark.asyncio +async def test_embedding_generate_embedding(mistralai_unit_test_env): + 
mock_client = AsyncMock(spec=MistralAsyncClient) + mock_embedding_response = MagicMock(spec=EmbeddingResponse, data=[MagicMock(embedding=[1, 2, 3, 4, 5])]) + mock_client.embeddings.return_value = mock_embedding_response + text_embedding = MistralAITextEmbedding(client=mock_client) + embedding = await text_embedding.generate_embeddings(["test"]) + assert embedding.tolist() == [[1, 2, 3, 4, 5]] + + +@pytest.mark.asyncio +async def test_embedding_generate_embedding_exception(mistralai_unit_test_env): + mock_client = AsyncMock(spec=MistralAsyncClient) + mock_client.embeddings.side_effect = Exception("Test Exception") + text_embedding = MistralAITextEmbedding(client=mock_client) + with pytest.raises(ServiceResponseException): + await text_embedding.generate_embeddings(["test"])