From 1304cc18ebe26afc317aa92b143e5876dc006557 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nico=20M=C3=B6ller?= Date: Sat, 6 Jul 2024 11:34:48 +0200 Subject: [PATCH 1/8] Initial Commit for Mistral Embeddings with TestCases --- .../connectors/ai/mistral_ai/__init__.py | 2 + .../services/mistral_ai_text_embedding.py | 88 +++++++++++++++++++ .../settings/mistral_ai_settings.py | 3 + python/tests/conftest.py | 3 +- .../embeddings/test_embedding_service.py | 79 +++++++++++++++++ .../test_mistralai_text_embeddings.py | 67 ++++++++++++++ 6 files changed, 241 insertions(+), 1 deletion(-) create mode 100644 python/semantic_kernel/connectors/ai/mistral_ai/services/mistral_ai_text_embedding.py create mode 100644 python/tests/integration/embeddings/test_embedding_service.py create mode 100644 python/tests/unit/connectors/mistral_ai/services/test_mistralai_text_embeddings.py diff --git a/python/semantic_kernel/connectors/ai/mistral_ai/__init__.py b/python/semantic_kernel/connectors/ai/mistral_ai/__init__.py index 9b2d7d379066..8dc0c473a53f 100644 --- a/python/semantic_kernel/connectors/ai/mistral_ai/__init__.py +++ b/python/semantic_kernel/connectors/ai/mistral_ai/__init__.py @@ -4,8 +4,10 @@ MistralAIChatPromptExecutionSettings, ) from semantic_kernel.connectors.ai.mistral_ai.services.mistral_ai_chat_completion import MistralAIChatCompletion +from semantic_kernel.connectors.ai.mistral_ai.services.mistral_ai_text_embedding import MistralAITextEmbedding __all__ = [ "MistralAIChatCompletion", "MistralAIChatPromptExecutionSettings", + "MistralAITextEmbedding", ] diff --git a/python/semantic_kernel/connectors/ai/mistral_ai/services/mistral_ai_text_embedding.py b/python/semantic_kernel/connectors/ai/mistral_ai/services/mistral_ai_text_embedding.py new file mode 100644 index 000000000000..600a5c81f730 --- /dev/null +++ b/python/semantic_kernel/connectors/ai/mistral_ai/services/mistral_ai_text_embedding.py @@ -0,0 +1,88 @@ +# Copyright (c) Microsoft. All rights reserved. + + +from typing import Any + +from mistralai.async_client import MistralAsyncClient +from mistralai.models.embeddings import EmbeddingResponse +from numpy import array, ndarray +from pydantic import ValidationError + +from semantic_kernel.connectors.ai.embeddings.embedding_generator_base import EmbeddingGeneratorBase +from semantic_kernel.connectors.ai.mistral_ai.settings.mistral_ai_settings import MistralAISettings +from semantic_kernel.exceptions.service_exceptions import ServiceInitializationError, ServiceResponseException +from semantic_kernel.utils.experimental_decorator import experimental_class + + +@experimental_class +class MistralAITextEmbedding(EmbeddingGeneratorBase): + """Mistral AI Inference Text Embedding Service.""" + + client: MistralAsyncClient + + def __init__( + self, + ai_model_id: str | None = None, + api_key: str | None = None, + service_id: str | None = None, + env_file_path: str | None = None, + env_file_encoding: str | None = None, + client: MistralAsyncClient | None = None, + ) -> None: + """Initialize the Mistral AI Text Embedding service. + + If no arguments are provided, the service will attempt to load the settings from the environment. + The following environment variables are used: + - MISTRALAI_API_KEY + - MISTRALAI_EMBEDDING_MODEL_ID + + Args: + ai_model_id: (str | None): A string that is used to identify the model such as the model name. + api_key (str | None): The API key for the Mistral AI service deployment. + service_id (str | None): Service ID for the embedding completion service. + env_file_path (str | None): The path to the environment file. + env_file_encoding (str | None): The encoding of the environment file. + client (MistralAsyncClient | None): The Mistral AI client to use. + + Raises: + ServiceInitializationError: If an error occurs during initialization. + """ + try: + mistralai_settings = MistralAISettings.create( + api_key=api_key, + embedding_model_id=ai_model_id, + env_file_path=env_file_path, + env_file_encoding=env_file_encoding, + ) + except ValidationError as e: + raise ServiceInitializationError(f"Failed to validate Mistral AI settings: {e}") from e + + if not mistralai_settings.chat_model_id: + raise ServiceInitializationError("The MistralAI embedding model ID is required.") + + if not client: + client = MistralAsyncClient( + api_key=mistralai_settings.api_key.get_secret_value() + ) + + super().__init__( + service_id=service_id or mistralai_settings.embedding_model_id, + ai_model_id=ai_model_id or mistralai_settings.embedding_model_id, + client=client, + ) + + async def generate_embeddings(self, texts: list[str], **kwargs: Any) -> ndarray: + """Generate embeddings from the Mistral AI service.""" + try: + + response: EmbeddingResponse = await self.client.embeddings( + model=self.ai_model_id, + input=texts + ) + except Exception as ex: + raise ServiceResponseException( + f"{type(self)} service failed to complete the embedding request.", + ex, + ) from ex + + return array([array(item.embedding) for item in response.data]) diff --git a/python/semantic_kernel/connectors/ai/mistral_ai/settings/mistral_ai_settings.py b/python/semantic_kernel/connectors/ai/mistral_ai/settings/mistral_ai_settings.py index 8139be0ba568..8acd90148d69 100644 --- a/python/semantic_kernel/connectors/ai/mistral_ai/settings/mistral_ai_settings.py +++ b/python/semantic_kernel/connectors/ai/mistral_ai/settings/mistral_ai_settings.py @@ -20,6 +20,8 @@ class MistralAISettings(KernelBaseSettings): (Env var MISTRALAI_API_KEY) - chat_model_id: str | None - The The Mistral AI chat model ID to use see https://docs.mistral.ai/getting-started/models/. (Env var MISTRALAI_CHAT_MODEL_ID) + - embedding_model_id: str | None - The The Mistral AI embedding model ID to use see https://docs.mistral.ai/getting-started/models/. + (Env var MISTRALAI_EMBEDDING_MODEL_ID) - env_file_path: str | None - if provided, the .env settings are read from this file path location """ @@ -27,3 +29,4 @@ class MistralAISettings(KernelBaseSettings): api_key: SecretStr chat_model_id: str | None = None + embedding_model_id: str | None = None diff --git a/python/tests/conftest.py b/python/tests/conftest.py index e5481f1cb445..68eb804407ab 100644 --- a/python/tests/conftest.py +++ b/python/tests/conftest.py @@ -260,7 +260,8 @@ def mistralai_unit_test_env(monkeypatch, exclude_list, override_env_param_dict): env_vars = { "MISTRALAI_CHAT_MODEL_ID": "test_chat_model_id", - "MISTRALAI_API_KEY": "test_api_key" + "MISTRALAI_API_KEY": "test_api_key", + "MISTRALAI_EMBEDDING_MODEL_ID": "test_embedding_model_id", } env_vars.update(override_env_param_dict) diff --git a/python/tests/integration/embeddings/test_embedding_service.py b/python/tests/integration/embeddings/test_embedding_service.py new file mode 100644 index 000000000000..c6151c53441d --- /dev/null +++ b/python/tests/integration/embeddings/test_embedding_service.py @@ -0,0 +1,79 @@ +# Copyright (c) Microsoft. All rights reserved. + +import os + +import pytest + +import semantic_kernel as sk +from semantic_kernel.connectors.ai.embeddings.embedding_generator_base import EmbeddingGeneratorBase +from semantic_kernel.connectors.ai.mistral_ai import MistralAITextEmbedding +from semantic_kernel.core_plugins.text_memory_plugin import TextMemoryPlugin +from semantic_kernel.kernel import Kernel +from semantic_kernel.memory.semantic_text_memory import SemanticTextMemory + +mistral_ai_setup: bool = False +try: + if os.environ["MISTRALAI_API_KEY"] and os.environ["MISTRALAI_CHAT_MODEL_ID"]: + mistral_ai_setup = True +except KeyError: + mistral_ai_setup = False + + +pytestmark = pytest.mark.parametrize("embeddings_generator", + [ + pytest.param( + MistralAITextEmbedding() if mistral_ai_setup else None, + marks=pytest.mark.skipif(not mistral_ai_setup, reason="Mistral AI environment variables not set"), + id="MistralEmbeddings" + ) + ] +) + + +@pytest.mark.asyncio(scope="module") +async def test_embedding_service(kernel: Kernel, embeddings_generator: EmbeddingGeneratorBase): + kernel.add_service(embeddings_generator) + + memory = SemanticTextMemory(storage=sk.memory.VolatileMemoryStore(), embeddings_generator=embeddings_generator) + kernel.add_plugin(TextMemoryPlugin(memory), "TextMemoryPlugin") + + await memory.save_reference( + "test", + external_id="info1", + text="this is a test", + external_source_name="external source", + ) + + # Add some documents to the semantic memory + await memory.save_information("test", id="info1", text="Sharks are fish.") + await memory.save_information("test", id="info2", text="Whales are mammals.") + await memory.save_information("test", id="info3", text="Penguins are birds.") + await memory.save_information("test", id="info4", text="Dolphins are mammals.") + await memory.save_information("test", id="info5", text="Flies are insects.") + + # Search for documents + query = "What are mammals?" + result = await memory.search("test", query, limit=2, min_relevance_score=0.0) + print(f"Query: {query}") + print(f"\tAnswer 1: {result[0].text}") + print(f"\tAnswer 2: {result[1].text}\n") + assert "mammals." in result[0].text + assert "mammals." in result[1].text + + query = "What are fish?" + result = await memory.search("test", query, limit=1, min_relevance_score=0.0) + print(f"Query: {query}") + print(f"\tAnswer: {result[0].text}\n") + assert result[0].text == "Sharks are fish." + + query = "What are insects?" + result = await memory.search("test", query, limit=1, min_relevance_score=0.0) + print(f"Query: {query}") + print(f"\tAnswer: {result[0].text}\n") + assert result[0].text == "Flies are insects." + + query = "What are birds?" + result = await memory.search("test", query, limit=1, min_relevance_score=0.0) + print(f"Query: {query}") + print(f"\tAnswer: {result[0].text}\n") + assert result[0].text == "Penguins are birds." diff --git a/python/tests/unit/connectors/mistral_ai/services/test_mistralai_text_embeddings.py b/python/tests/unit/connectors/mistral_ai/services/test_mistralai_text_embeddings.py new file mode 100644 index 000000000000..25d93e3d83c8 --- /dev/null +++ b/python/tests/unit/connectors/mistral_ai/services/test_mistralai_text_embeddings.py @@ -0,0 +1,67 @@ +# Copyright (c) Microsoft. All rights reserved. + +from unittest.mock import AsyncMock, MagicMock + +import pytest +from mistralai.async_client import MistralAsyncClient + +from semantic_kernel.connectors.ai.mistral_ai.services.mistral_ai_text_embedding import MistralAITextEmbedding +from semantic_kernel.exceptions.service_exceptions import ServiceInitializationError, ServiceResponseException + + +def test_embedding_with_env_variables(mistralai_unit_test_env): + text_embedding = MistralAITextEmbedding() + assert text_embedding.ai_model_id == "test_embedding_model_id" + assert text_embedding.client._api_key == "test_api_key" + + +def test_embedding_with_client(mistralai_unit_test_env): + client = MagicMock(spec=MistralAsyncClient) + text_embedding = MistralAITextEmbedding(client=client) + assert text_embedding.client == client + assert text_embedding.ai_model_id == "test_embedding_model_id" + + +def test_embedding_with_api_key(mistralai_unit_test_env): + text_embedding = MistralAITextEmbedding(api_key="overwrite-api-key") + assert text_embedding.client._api_key == "overwrite-api-key" + + +def test_embedding_with_model(mistralai_unit_test_env): + text_embedding = MistralAITextEmbedding(ai_model_id="overwrite-model") + assert text_embedding.ai_model_id == "overwrite-model" + + +@pytest.mark.parametrize("exclude_list", [["MISTRALAI_CHAT_MODEL_ID"]], indirect=True) +def test_embedding_missing_model(mistralai_unit_test_env): + with pytest.raises(ServiceInitializationError): + MistralAITextEmbedding( + env_file_path="test.env", + ) + + +@pytest.mark.parametrize("exclude_list", [["MISTRALAI_API_KEY"]], indirect=True) +def test_embedding_missing_api_key(mistralai_unit_test_env): + with pytest.raises(ServiceInitializationError): + MistralAITextEmbedding( + env_file_path="test.env", + ) + + +@pytest.mark.asyncio +async def test_embedding_generate_embedding(mistralai_unit_test_env): + mock_client = AsyncMock(spec=MistralAsyncClient) + mock_embedding_response = MagicMock(data=[MagicMock(embedding=[1, 2, 3, 4, 5])]) + mock_client.embeddings.return_value = mock_embedding_response + text_embedding = MistralAITextEmbedding(client=mock_client) + embedding = await text_embedding.generate_embeddings(["test"]) + assert embedding.tolist() == [[1, 2, 3, 4, 5]] + + +@pytest.mark.asyncio +async def test_embedding_generate_embedding_exception(mistralai_unit_test_env): + mock_client = AsyncMock(spec=MistralAsyncClient) + mock_client.embeddings.side_effect = Exception("Test Exception") + text_embedding = MistralAITextEmbedding(client=mock_client) + with pytest.raises(ServiceResponseException): + await text_embedding.generate_embeddings(["test"]) From 8dbf320b744df876499d00fdfc9f7dac09bb1eba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nico=20M=C3=B6ller?= Date: Mon, 8 Jul 2024 08:46:43 +0200 Subject: [PATCH 2/8] small naming fixes more testcases --- .../workflows/python-integration-tests.yml | 2 ++ .../services/mistral_ai_text_embedding.py | 2 +- .../embeddings/test_embedding_service.py | 2 +- .../test_mistralai_text_embeddings.py | 31 ++++++++++++++++++- 4 files changed, 34 insertions(+), 3 deletions(-) diff --git a/.github/workflows/python-integration-tests.yml b/.github/workflows/python-integration-tests.yml index 076c66b3368a..016a9dec267b 100644 --- a/.github/workflows/python-integration-tests.yml +++ b/.github/workflows/python-integration-tests.yml @@ -98,6 +98,7 @@ jobs: ACA_POOL_MANAGEMENT_ENDPOINT: ${{secrets.ACA_POOL_MANAGEMENT_ENDPOINT}} MISTRALAI_API_KEY: ${{secrets.MISTRALAI_API_KEY}} MISTRALAI_CHAT_MODEL_ID: ${{ vars.MISTRALAI_CHAT_MODEL_ID }} + MISTRALAI_EMBEDDING_MODEL_ID: ${{ vars.MISTRALAI_EMBEDDING_MODEL_ID }} run: | if ${{ matrix.os == 'ubuntu-latest' }}; then docker run -d --name redis-stack-server -p 6379:6379 redis/redis-stack-server:latest @@ -167,6 +168,7 @@ jobs: ACA_POOL_MANAGEMENT_ENDPOINT: ${{secrets.ACA_POOL_MANAGEMENT_ENDPOINT}} MISTRALAI_API_KEY: ${{secrets.MISTRALAI_API_KEY}} MISTRALAI_CHAT_MODEL_ID: ${{ vars.MISTRALAI_CHAT_MODEL_ID }} + MISTRALAI_EMBEDDING_MODEL_ID: ${{ vars.MISTRALAI_EMBEDDING_MODEL_ID }} run: | if ${{ matrix.os == 'ubuntu-latest' }}; then docker run -d --name redis-stack-server -p 6379:6379 redis/redis-stack-server:latest diff --git a/python/semantic_kernel/connectors/ai/mistral_ai/services/mistral_ai_text_embedding.py b/python/semantic_kernel/connectors/ai/mistral_ai/services/mistral_ai_text_embedding.py index 600a5c81f730..f4bea0e6b40c 100644 --- a/python/semantic_kernel/connectors/ai/mistral_ai/services/mistral_ai_text_embedding.py +++ b/python/semantic_kernel/connectors/ai/mistral_ai/services/mistral_ai_text_embedding.py @@ -57,7 +57,7 @@ def __init__( except ValidationError as e: raise ServiceInitializationError(f"Failed to validate Mistral AI settings: {e}") from e - if not mistralai_settings.chat_model_id: + if not mistralai_settings.embedding_model_id: raise ServiceInitializationError("The MistralAI embedding model ID is required.") if not client: diff --git a/python/tests/integration/embeddings/test_embedding_service.py b/python/tests/integration/embeddings/test_embedding_service.py index c6151c53441d..60917da23dca 100644 --- a/python/tests/integration/embeddings/test_embedding_service.py +++ b/python/tests/integration/embeddings/test_embedding_service.py @@ -13,7 +13,7 @@ mistral_ai_setup: bool = False try: - if os.environ["MISTRALAI_API_KEY"] and os.environ["MISTRALAI_CHAT_MODEL_ID"]: + if os.environ["MISTRALAI_API_KEY"] and os.environ["MISTRALAI_EMBEDDING_MODEL_ID"]: mistral_ai_setup = True except KeyError: mistral_ai_setup = False diff --git a/python/tests/unit/connectors/mistral_ai/services/test_mistralai_text_embeddings.py b/python/tests/unit/connectors/mistral_ai/services/test_mistralai_text_embeddings.py index 25d93e3d83c8..25ee5203db90 100644 --- a/python/tests/unit/connectors/mistral_ai/services/test_mistralai_text_embeddings.py +++ b/python/tests/unit/connectors/mistral_ai/services/test_mistralai_text_embeddings.py @@ -15,6 +15,16 @@ def test_embedding_with_env_variables(mistralai_unit_test_env): assert text_embedding.client._api_key == "test_api_key" +@pytest.mark.parametrize("exclude_list", [["MISTRALAI_API_KEY", "MISTRALAI_EMBEDDING_MODEL_ID"]], indirect=True) +def test_embedding_with_constructor(mistralai_unit_test_env): + text_embedding = MistralAITextEmbedding( + api_key="overwrite-api-key", + ai_model_id="overwrite-model", + ) + assert text_embedding.ai_model_id == "overwrite-model" + assert text_embedding.client._api_key == "overwrite-api-key" + + def test_embedding_with_client(mistralai_unit_test_env): client = MagicMock(spec=MistralAsyncClient) text_embedding = MistralAITextEmbedding(client=client) @@ -25,14 +35,16 @@ def test_embedding_with_client(mistralai_unit_test_env): def test_embedding_with_api_key(mistralai_unit_test_env): text_embedding = MistralAITextEmbedding(api_key="overwrite-api-key") assert text_embedding.client._api_key == "overwrite-api-key" + assert text_embedding.ai_model_id == "test_embedding_model_id" def test_embedding_with_model(mistralai_unit_test_env): text_embedding = MistralAITextEmbedding(ai_model_id="overwrite-model") assert text_embedding.ai_model_id == "overwrite-model" + assert text_embedding.client._api_key == "test_api_key" -@pytest.mark.parametrize("exclude_list", [["MISTRALAI_CHAT_MODEL_ID"]], indirect=True) +@pytest.mark.parametrize("exclude_list", [["MISTRALAI_EMBEDDING_MODEL_ID"]], indirect=True) def test_embedding_missing_model(mistralai_unit_test_env): with pytest.raises(ServiceInitializationError): MistralAITextEmbedding( @@ -48,6 +60,23 @@ def test_embedding_missing_api_key(mistralai_unit_test_env): ) +@pytest.mark.parametrize("exclude_list", [["MISTRALAI_API_KEY", "MISTRALAI_EMBEDDING_MODEL_ID"]], indirect=True) +def test_embedding_missing_api_key_constructor(mistralai_unit_test_env): + with pytest.raises(ServiceInitializationError): + MistralAITextEmbedding( + env_file_path="test.env", + ) + + +@pytest.mark.parametrize("exclude_list", [["MISTRALAI_API_KEY", "MISTRALAI_EMBEDDING_MODEL_ID"]], indirect=True) +def test_embedding_missing_model_constructor(mistralai_unit_test_env): + with pytest.raises(ServiceInitializationError): + MistralAITextEmbedding( + api_key="test_api_key", + env_file_path="test.env", + ) + + @pytest.mark.asyncio async def test_embedding_generate_embedding(mistralai_unit_test_env): mock_client = AsyncMock(spec=MistralAsyncClient) From 4f48ca1cf7bef41de9017e5520006866489561fb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nico=20M=C3=B6ller?= Date: Mon, 8 Jul 2024 09:11:18 +0200 Subject: [PATCH 3/8] added some more test cases for completion --- .../test_mistralai_chat_completion.py | 43 ++++++++++++++++--- 1 file changed, 38 insertions(+), 5 deletions(-) diff --git a/python/tests/unit/connectors/mistral_ai/services/test_mistralai_chat_completion.py b/python/tests/unit/connectors/mistral_ai/services/test_mistralai_chat_completion.py index 8510fbae3ea5..28f53ddc2ad0 100644 --- a/python/tests/unit/connectors/mistral_ai/services/test_mistralai_chat_completion.py +++ b/python/tests/unit/connectors/mistral_ai/services/test_mistralai_chat_completion.py @@ -112,20 +112,53 @@ def test_mistral_ai_chat_completion_init(mistralai_unit_test_env) -> None: mistral_ai_chat_completion = MistralAIChatCompletion() assert mistral_ai_chat_completion.ai_model_id == mistralai_unit_test_env["MISTRALAI_CHAT_MODEL_ID"] + assert mistral_ai_chat_completion.async_client._api_key == mistralai_unit_test_env["MISTRALAI_API_KEY"] assert isinstance(mistral_ai_chat_completion, ChatCompletionClientBase) -@pytest.mark.parametrize("exclude_list", [["MISTRALAI_API_KEY"]], indirect=True) -def test_mistral_ai_chat_completion_init_with_empty_api_key(mistralai_unit_test_env) -> None: - ai_model_id = "test_model_id" +@pytest.mark.parametrize("exclude_list", [["MISTRALAI_API_KEY", "MISTRALAI_CHAT_MODEL_ID"]], indirect=True) +def test_mistral_ai_chat_completion_init_constructor(mistralai_unit_test_env) -> None: + # Test successful initialization + mistral_ai_chat_completion = MistralAIChatCompletion( + api_key="overwrite_api_key", + ai_model_id="overwrite_model_id", + env_file_path="test.env", + ) + + assert mistral_ai_chat_completion.ai_model_id == "overwrite_model_id" + assert mistral_ai_chat_completion.async_client._api_key == "overwrite_api_key" + assert isinstance(mistral_ai_chat_completion, ChatCompletionClientBase) + +@pytest.mark.parametrize("exclude_list", [["MISTRALAI_API_KEY", "MISTRALAI_CHAT_MODEL_ID"]], indirect=True) +def test_mistral_ai_chat_completion_init_constructor_missing_model(mistralai_unit_test_env) -> None: + # Test successful initialization with pytest.raises(ServiceInitializationError): MistralAIChatCompletion( - ai_model_id=ai_model_id, - env_file_path="test.env", + api_key="overwrite_api_key", + env_file_path="test.env" ) +@pytest.mark.parametrize("exclude_list", [["MISTRALAI_API_KEY", "MISTRALAI_CHAT_MODEL_ID"]], indirect=True) +def test_mistral_ai_chat_completion_init_constructor_missing_api_key(mistralai_unit_test_env) -> None: + # Test successful initialization + with pytest.raises(ServiceInitializationError): + MistralAIChatCompletion( + ai_model_id="overwrite_model_id", + env_file_path="test.env" + ) + + +def test_mistral_ai_chat_completion_init_hybrid(mistralai_unit_test_env) -> None: + mistral_ai_chat_completion = MistralAIChatCompletion( + ai_model_id="overwrite_model_id", + env_file_path="test.env", + ) + assert mistral_ai_chat_completion.ai_model_id == "overwrite_model_id" + assert mistral_ai_chat_completion.async_client._api_key == "test_api_key" + + @pytest.mark.parametrize("exclude_list", [["MISTRALAI_CHAT_MODEL_ID"]], indirect=True) def test_mistral_ai_chat_completion_init_with_empty_model_id(mistralai_unit_test_env) -> None: with pytest.raises(ServiceInitializationError): From df8a64e898f30f48321769aa3c75a4d389bdbf9d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nico=20M=C3=B6ller?= Date: Mon, 22 Jul 2024 15:39:29 +0200 Subject: [PATCH 4/8] Integrated PR Feedback --- .../services/mistral_ai_text_embedding.py | 32 ++++++++++++++++--- .../test_mistralai_text_embeddings.py | 14 +++++++- 2 files changed, 41 insertions(+), 5 deletions(-) diff --git a/python/semantic_kernel/connectors/ai/mistral_ai/services/mistral_ai_text_embedding.py b/python/semantic_kernel/connectors/ai/mistral_ai/services/mistral_ai_text_embedding.py index f4bea0e6b40c..53f73d2fd982 100644 --- a/python/semantic_kernel/connectors/ai/mistral_ai/services/mistral_ai_text_embedding.py +++ b/python/semantic_kernel/connectors/ai/mistral_ai/services/mistral_ai_text_embedding.py @@ -1,7 +1,12 @@ # Copyright (c) Microsoft. All rights reserved. +import sys -from typing import Any +if sys.version_info >= (3, 12): + from typing import Any, override # pragma: no cover +else: + from typing_extensions import Any, override # pragma: no cover +import logging from mistralai.async_client import MistralAsyncClient from mistralai.models.embeddings import EmbeddingResponse @@ -10,9 +15,12 @@ from semantic_kernel.connectors.ai.embeddings.embedding_generator_base import EmbeddingGeneratorBase from semantic_kernel.connectors.ai.mistral_ai.settings.mistral_ai_settings import MistralAISettings +from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings from semantic_kernel.exceptions.service_exceptions import ServiceInitializationError, ServiceResponseException from semantic_kernel.utils.experimental_decorator import experimental_class +logger: logging.Logger = logging.getLogger(__name__) + @experimental_class class MistralAITextEmbedding(EmbeddingGeneratorBase): @@ -71,7 +79,23 @@ def __init__( client=client, ) - async def generate_embeddings(self, texts: list[str], **kwargs: Any) -> ndarray: + @override + async def generate_embeddings( + self, + texts: list[str], + settings: "PromptExecutionSettings | None" = None, + **kwargs: Any, + ) -> ndarray: + embedding_response: EmbeddingResponse = await self.generate_raw_embeddings(texts, settings, **kwargs) + return array([array(item.embedding) for item in embedding_response.data]) + + @override + async def generate_raw_embeddings( + self, + texts: list[str], + settings: "PromptExecutionSettings | None" = None, + **kwargs: Any, + ) -> "EmbeddingResponse": """Generate embeddings from the Mistral AI service.""" try: @@ -83,6 +107,6 @@ async def generate_embeddings(self, texts: list[str], **kwargs: Any) -> ndarray: raise ServiceResponseException( f"{type(self)} service failed to complete the embedding request.", ex, - ) from ex + ) from ex - return array([array(item.embedding) for item in response.data]) + return response diff --git a/python/tests/unit/connectors/mistral_ai/services/test_mistralai_text_embeddings.py b/python/tests/unit/connectors/mistral_ai/services/test_mistralai_text_embeddings.py index 25ee5203db90..a664caaa46fb 100644 --- a/python/tests/unit/connectors/mistral_ai/services/test_mistralai_text_embeddings.py +++ b/python/tests/unit/connectors/mistral_ai/services/test_mistralai_text_embeddings.py @@ -4,6 +4,7 @@ import pytest from mistralai.async_client import MistralAsyncClient +from mistralai.models.embeddings import EmbeddingResponse from semantic_kernel.connectors.ai.mistral_ai.services.mistral_ai_text_embedding import MistralAITextEmbedding from semantic_kernel.exceptions.service_exceptions import ServiceInitializationError, ServiceResponseException @@ -77,10 +78,21 @@ def test_embedding_missing_model_constructor(mistralai_unit_test_env): ) +@pytest.mark.asyncio +async def test_embedding_generate_raw_embedding(mistralai_unit_test_env): + mock_client = AsyncMock(spec=MistralAsyncClient) + mock_embedding_response = MagicMock(spec=EmbeddingResponse, data=[MagicMock(embedding=[1, 2, 3, 4, 5])]) + mock_client.embeddings.return_value = mock_embedding_response + text_embedding = MistralAITextEmbedding(client=mock_client) + embedding = await text_embedding.generate_raw_embeddings(["test"]) + assert isinstance(embedding, EmbeddingResponse) + assert embedding.data[0].embedding == [1, 2, 3, 4, 5] + + @pytest.mark.asyncio async def test_embedding_generate_embedding(mistralai_unit_test_env): mock_client = AsyncMock(spec=MistralAsyncClient) - mock_embedding_response = MagicMock(data=[MagicMock(embedding=[1, 2, 3, 4, 5])]) + mock_embedding_response = MagicMock(spec=EmbeddingResponse, data=[MagicMock(embedding=[1, 2, 3, 4, 5])]) mock_client.embeddings.return_value = mock_embedding_response text_embedding = MistralAITextEmbedding(client=mock_client) embedding = await text_embedding.generate_embeddings(["test"]) From c1b35c8a1eb91f94256d5d24eaa5ef6c5c849049 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nico=20M=C3=B6ller?= Date: Mon, 22 Jul 2024 16:08:43 +0200 Subject: [PATCH 5/8] removed unencessary cast --- .../ai/mistral_ai/services/mistral_ai_text_embedding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/semantic_kernel/connectors/ai/mistral_ai/services/mistral_ai_text_embedding.py b/python/semantic_kernel/connectors/ai/mistral_ai/services/mistral_ai_text_embedding.py index 53f73d2fd982..1acfbeee9e85 100644 --- a/python/semantic_kernel/connectors/ai/mistral_ai/services/mistral_ai_text_embedding.py +++ b/python/semantic_kernel/connectors/ai/mistral_ai/services/mistral_ai_text_embedding.py @@ -87,7 +87,7 @@ async def generate_embeddings( **kwargs: Any, ) -> ndarray: embedding_response: EmbeddingResponse = await self.generate_raw_embeddings(texts, settings, **kwargs) - return array([array(item.embedding) for item in embedding_response.data]) + return array([item.embedding for item in embedding_response.data]) @override async def generate_raw_embeddings( From fb3b8f2055b36e753b7813bf2dd1eb32feaac57c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nico=20M=C3=B6ller?= Date: Thu, 25 Jul 2024 11:33:55 +0200 Subject: [PATCH 6/8] addressed PR Feedback --- .../ai/mistral_ai/services/mistral_ai_text_embedding.py | 8 ++++---- .../mistral_ai/services/test_mistralai_text_embeddings.py | 3 +-- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/python/semantic_kernel/connectors/ai/mistral_ai/services/mistral_ai_text_embedding.py b/python/semantic_kernel/connectors/ai/mistral_ai/services/mistral_ai_text_embedding.py index 1acfbeee9e85..9ec4a365a24e 100644 --- a/python/semantic_kernel/connectors/ai/mistral_ai/services/mistral_ai_text_embedding.py +++ b/python/semantic_kernel/connectors/ai/mistral_ai/services/mistral_ai_text_embedding.py @@ -87,7 +87,7 @@ async def generate_embeddings( **kwargs: Any, ) -> ndarray: embedding_response: EmbeddingResponse = await self.generate_raw_embeddings(texts, settings, **kwargs) - return array([item.embedding for item in embedding_response.data]) + return array(embedding_response) @override async def generate_raw_embeddings( @@ -95,11 +95,11 @@ async def generate_raw_embeddings( texts: list[str], settings: "PromptExecutionSettings | None" = None, **kwargs: Any, - ) -> "EmbeddingResponse": + ) -> "Any": """Generate embeddings from the Mistral AI service.""" try: - response: EmbeddingResponse = await self.client.embeddings( + embedding_response: EmbeddingResponse = await self.client.embeddings( model=self.ai_model_id, input=texts ) @@ -109,4 +109,4 @@ async def generate_raw_embeddings( ex, ) from ex - return response + return [item.embedding for item in embedding_response.data] diff --git a/python/tests/unit/connectors/mistral_ai/services/test_mistralai_text_embeddings.py b/python/tests/unit/connectors/mistral_ai/services/test_mistralai_text_embeddings.py index a664caaa46fb..5c3b2ef07ced 100644 --- a/python/tests/unit/connectors/mistral_ai/services/test_mistralai_text_embeddings.py +++ b/python/tests/unit/connectors/mistral_ai/services/test_mistralai_text_embeddings.py @@ -85,8 +85,7 @@ async def test_embedding_generate_raw_embedding(mistralai_unit_test_env): mock_client.embeddings.return_value = mock_embedding_response text_embedding = MistralAITextEmbedding(client=mock_client) embedding = await text_embedding.generate_raw_embeddings(["test"]) - assert isinstance(embedding, EmbeddingResponse) - assert embedding.data[0].embedding == [1, 2, 3, 4, 5] + assert embedding == [[1, 2, 3, 4, 5]] @pytest.mark.asyncio From b09db5a4da32ddcc54a7be4526a077bec637e6ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nico=20M=C3=B6ller?= Date: Thu, 1 Aug 2024 10:25:54 +0200 Subject: [PATCH 7/8] Update python/semantic_kernel/connectors/ai/mistral_ai/services/mistral_ai_text_embedding.py Co-authored-by: Eduard van Valkenburg --- .../ai/mistral_ai/services/mistral_ai_text_embedding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/semantic_kernel/connectors/ai/mistral_ai/services/mistral_ai_text_embedding.py b/python/semantic_kernel/connectors/ai/mistral_ai/services/mistral_ai_text_embedding.py index 9ec4a365a24e..7c44e2cfb7ce 100644 --- a/python/semantic_kernel/connectors/ai/mistral_ai/services/mistral_ai_text_embedding.py +++ b/python/semantic_kernel/connectors/ai/mistral_ai/services/mistral_ai_text_embedding.py @@ -95,7 +95,7 @@ async def generate_raw_embeddings( texts: list[str], settings: "PromptExecutionSettings | None" = None, **kwargs: Any, - ) -> "Any": + ) -> Any: """Generate embeddings from the Mistral AI service.""" try: From 0350506527927eea16cefe424b5955d9b9649a22 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nico=20M=C3=B6ller?= Date: Mon, 5 Aug 2024 09:50:11 +0200 Subject: [PATCH 8/8] fixed response type of generate embeddings --- .../ai/mistral_ai/services/mistral_ai_text_embedding.py | 2 +- .../mistral_ai/services/test_mistralai_text_embeddings.py | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/python/semantic_kernel/connectors/ai/mistral_ai/services/mistral_ai_text_embedding.py b/python/semantic_kernel/connectors/ai/mistral_ai/services/mistral_ai_text_embedding.py index 7c44e2cfb7ce..24b2905b1587 100644 --- a/python/semantic_kernel/connectors/ai/mistral_ai/services/mistral_ai_text_embedding.py +++ b/python/semantic_kernel/connectors/ai/mistral_ai/services/mistral_ai_text_embedding.py @@ -86,7 +86,7 @@ async def generate_embeddings( settings: "PromptExecutionSettings | None" = None, **kwargs: Any, ) -> ndarray: - embedding_response: EmbeddingResponse = await self.generate_raw_embeddings(texts, settings, **kwargs) + embedding_response = await self.generate_raw_embeddings(texts, settings, **kwargs) return array(embedding_response) @override diff --git a/python/tests/unit/connectors/mistral_ai/services/test_mistralai_text_embeddings.py b/python/tests/unit/connectors/mistral_ai/services/test_mistralai_text_embeddings.py index 5c3b2ef07ced..98550ca6f1ad 100644 --- a/python/tests/unit/connectors/mistral_ai/services/test_mistralai_text_embeddings.py +++ b/python/tests/unit/connectors/mistral_ai/services/test_mistralai_text_embeddings.py @@ -45,6 +45,13 @@ def test_embedding_with_model(mistralai_unit_test_env): assert text_embedding.client._api_key == "test_api_key" +@pytest.mark.parametrize("exclude_list", [["MISTRALAI_EMBEDDING_MODEL_ID"]], indirect=True) +def test_embedding_with_model_without_env(mistralai_unit_test_env): + text_embedding = MistralAITextEmbedding(ai_model_id="overwrite-model") + assert text_embedding.ai_model_id == "overwrite-model" + assert text_embedding.client._api_key == "test_api_key" + + @pytest.mark.parametrize("exclude_list", [["MISTRALAI_EMBEDDING_MODEL_ID"]], indirect=True) def test_embedding_missing_model(mistralai_unit_test_env): with pytest.raises(ServiceInitializationError):