From d4085d2ee2e2567b020ffc91777505bee54247fb Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Tue, 30 Jul 2024 10:41:59 -0700 Subject: [PATCH 01/13] Google AI function calling, next streaming --- .../google_ai_prompt_execution_settings.py | 4 +- .../services/google_ai_chat_completion.py | 113 ++++++++++++++++- .../ai/google/google_ai/services/utils.py | 116 +++++++++++++++++- .../contents/function_result_content.py | 4 +- .../completions/test_chat_completions.py | 15 +++ 5 files changed, 242 insertions(+), 10 deletions(-) diff --git a/python/semantic_kernel/connectors/ai/google/google_ai/google_ai_prompt_execution_settings.py b/python/semantic_kernel/connectors/ai/google/google_ai/google_ai_prompt_execution_settings.py index 94c0cc9c17cf..a5eff4dc2b98 100644 --- a/python/semantic_kernel/connectors/ai/google/google_ai/google_ai_prompt_execution_settings.py +++ b/python/semantic_kernel/connectors/ai/google/google_ai/google_ai_prompt_execution_settings.py @@ -36,7 +36,7 @@ class GoogleAIChatPromptExecutionSettings(GoogleAIPromptExecutionSettings): """Google AI Chat Prompt Execution Settings.""" tools: list[dict[str, Any]] | None = Field(None, max_length=64) - tool_choice: str | None = None + tool_config: dict[str, Any] | None = None @override def prepare_settings_dict(self, **kwargs) -> dict[str, Any]: @@ -47,7 +47,7 @@ def prepare_settings_dict(self, **kwargs) -> dict[str, Any]: """ settings_dict = super().prepare_settings_dict(**kwargs) settings_dict.pop("tools", None) - settings_dict.pop("tool_choice", None) + settings_dict.pop("tool_config", None) return settings_dict diff --git a/python/semantic_kernel/connectors/ai/google/google_ai/services/google_ai_chat_completion.py b/python/semantic_kernel/connectors/ai/google/google_ai/services/google_ai_chat_completion.py index 8f928e05059a..31e452ad5614 100644 --- a/python/semantic_kernel/connectors/ai/google/google_ai/services/google_ai_chat_completion.py +++ b/python/semantic_kernel/connectors/ai/google/google_ai/services/google_ai_chat_completion.py @@ -1,6 +1,8 @@ # Copyright (c) Microsoft. All rights reserved. 
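+# asyncio is imported to run the tool calls returned by the model in parallel
+# (see _invoke_function_calls below).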
+import asyncio +import logging import sys from collections.abc import AsyncGenerator from typing import TYPE_CHECKING, Any @@ -11,6 +13,7 @@ from google.generativeai.types import AsyncGenerateContentResponse, GenerateContentResponse, GenerationConfig from pydantic import ValidationError +from semantic_kernel.connectors.ai.function_choice_behavior import FunctionChoiceBehavior from semantic_kernel.connectors.ai.google.google_ai.google_ai_prompt_execution_settings import ( GoogleAIChatPromptExecutionSettings, ) @@ -19,11 +22,18 @@ filter_system_message, finish_reason_from_google_ai_to_semantic_kernel, format_assistant_message, + format_gemini_function_name_to_kernel_function_fully_qualified_name, + format_tool_message, format_user_message, + update_settings_from_function_choice_configuration, ) +from semantic_kernel.contents.function_call_content import FunctionCallContent from semantic_kernel.contents.streaming_chat_message_content import StreamingChatMessageContent +from semantic_kernel.contents.text_content import TextContent from semantic_kernel.contents.utils.author_role import AuthorRole from semantic_kernel.contents.utils.finish_reason import FinishReason +from semantic_kernel.functions.kernel_arguments import KernelArguments +from semantic_kernel.kernel import Kernel if sys.version_info >= (3, 12): from typing import override # pragma: no cover @@ -33,12 +43,17 @@ from semantic_kernel.connectors.ai.chat_completion_client_base import ChatCompletionClientBase from semantic_kernel.connectors.ai.google.google_ai.google_ai_settings import GoogleAISettings from semantic_kernel.contents.chat_history import ChatHistory -from semantic_kernel.contents.chat_message_content import ChatMessageContent -from semantic_kernel.exceptions.service_exceptions import ServiceInitializationError +from semantic_kernel.contents.chat_message_content import ITEM_TYPES, ChatMessageContent +from semantic_kernel.exceptions.service_exceptions import ( + ServiceInitializationError, + ServiceInvalidExecutionSettingsError, +) if TYPE_CHECKING: from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings +logger: logging.Logger = logging.getLogger(__name__) + class GoogleAIChatCompletion(GoogleAIBase, ChatCompletionClientBase): """Google AI Chat Completion Client.""" @@ -97,7 +112,40 @@ async def get_chat_message_contents( settings = self.get_prompt_execution_settings_from_settings(settings) assert isinstance(settings, GoogleAIChatPromptExecutionSettings) # nosec - return await self._send_chat_request(chat_history, settings) + if ( + settings.function_choice_behavior is None + or not settings.function_choice_behavior.auto_invoke_kernel_functions + ): + return await self._send_chat_request(chat_history, settings) + + kernel = kwargs.get("kernel") + if not isinstance(kernel, Kernel): + raise ServiceInvalidExecutionSettingsError("Kernel is required for auto invoking functions.") + + self._configure_function_choice_behavior(settings, kernel) + + for request_index in range(settings.function_choice_behavior.maximum_auto_invoke_attempts): + completions = await self._send_chat_request(chat_history, settings) + chat_history.add_message(message=completions[0]) + function_calls = [item for item in chat_history.messages[-1].items if isinstance(item, FunctionCallContent)] + if (fc_count := len(function_calls)) == 0: + return completions + + results = await self._invoke_function_calls( + function_calls=function_calls, + chat_history=chat_history, + kernel=kernel, + arguments=kwargs.get("argument", 
None), + function_call_count=fc_count, + request_index=request_index, + function_behavior=settings.function_choice_behavior, + ) + + if any(result.terminate for result in results if result is not None): + return completions + else: + # do a final call without auto function calling + return await self._send_chat_request(chat_history, settings) async def _send_chat_request( self, chat_history: ChatHistory, settings: GoogleAIChatPromptExecutionSettings @@ -112,6 +160,8 @@ async def _send_chat_request( response: AsyncGenerateContentResponse = await model.generate_content_async( contents=self._prepare_chat_history_for_request(chat_history), generation_config=GenerationConfig(**settings.prepare_settings_dict()), + tools=settings.tools, + tool_config=settings.tool_config, ) return [self._create_chat_message_content(response, candidate) for candidate in response.candidates] @@ -133,10 +183,25 @@ def _create_chat_message_content( response_metadata = self._get_metadata_from_response(response) response_metadata.update(self._get_metadata_from_candidate(candidate)) + items: list[ITEM_TYPES] = [] + for idx, part in enumerate(candidate.content.parts): + if part.text: + items.append(TextContent(text=part.text, inner_content=response, metadata=response_metadata)) + elif part.function_call: + items.append( + FunctionCallContent( + id=f"{part.function_call.name}_{idx!s}", + name=format_gemini_function_name_to_kernel_function_fully_qualified_name( + part.function_call.name + ), + arguments={k: v for k, v in part.function_call.args.items()}, + ) + ) + return ChatMessageContent( ai_model_id=self.ai_model_id, role=AuthorRole.ASSISTANT, - content=candidate.content.parts[0].text, + items=items, inner_content=response, finish_reason=finish_reason, metadata=response_metadata, @@ -230,6 +295,8 @@ def _prepare_chat_history_for_request( chat_request_messages.append(Content(role="user", parts=format_user_message(message))) elif message.role == AuthorRole.ASSISTANT: chat_request_messages.append(Content(role="model", parts=format_assistant_message(message))) + elif message.role == AuthorRole.TOOL: + chat_request_messages.append(Content(role="function", parts=format_tool_message(message))) else: raise ValueError(f"Unsupported role: {message.role}") @@ -267,6 +334,44 @@ def _get_metadata_from_candidate(self, candidate: Candidate) -> dict[str, Any]: "token_count": candidate.token_count, } + def _configure_function_choice_behavior(self, settings: GoogleAIChatPromptExecutionSettings, kernel: Kernel): + """Configure the function choice behavior to include the kernel functions.""" + if not settings.function_choice_behavior: + raise ServiceInvalidExecutionSettingsError("Function choice behavior is required for tool calls.") + + settings.function_choice_behavior.configure( + kernel=kernel, + update_settings_callback=update_settings_from_function_choice_configuration, + settings=settings, + ) + + async def _invoke_function_calls( + self, + function_calls: list[FunctionCallContent], + chat_history: ChatHistory, + kernel: Kernel, + arguments: KernelArguments | None, + function_call_count: int, + request_index: int, + function_behavior: FunctionChoiceBehavior, + ): + """Invoke function calls.""" + logger.info(f"processing {function_call_count} tool calls in parallel.") + + return await asyncio.gather( + *[ + kernel.invoke_function_call( + function_call=function_call, + chat_history=chat_history, + arguments=arguments, + function_call_count=function_call_count, + request_index=request_index, + function_behavior=function_behavior, + ) + 
for function_call in function_calls + ], + ) + @override def get_prompt_execution_settings_class( self, diff --git a/python/semantic_kernel/connectors/ai/google/google_ai/services/utils.py b/python/semantic_kernel/connectors/ai/google/google_ai/services/utils.py index b1eb6aa1bc57..2ab9d4fe4032 100644 --- a/python/semantic_kernel/connectors/ai/google/google_ai/services/utils.py +++ b/python/semantic_kernel/connectors/ai/google/google_ai/services/utils.py @@ -1,16 +1,24 @@ # Copyright (c) Microsoft. All rights reserved. +import json import logging +from typing import Any -from google.generativeai.protos import Blob, Candidate, Part +from google.generativeai.protos import Blob, Candidate, FunctionResponse, Part +from semantic_kernel.connectors.ai.function_choice_behavior import FunctionCallChoiceConfiguration, FunctionChoiceType +from semantic_kernel.connectors.ai.google.google_ai.google_ai_prompt_execution_settings import ( + GoogleAIChatPromptExecutionSettings, +) from semantic_kernel.contents.chat_history import ChatHistory from semantic_kernel.contents.chat_message_content import ChatMessageContent +from semantic_kernel.contents.function_result_content import FunctionResultContent from semantic_kernel.contents.image_content import ImageContent from semantic_kernel.contents.text_content import TextContent from semantic_kernel.contents.utils.author_role import AuthorRole from semantic_kernel.contents.utils.finish_reason import FinishReason as SemanticKernelFinishReason from semantic_kernel.exceptions.service_exceptions import ServiceInvalidRequestError +from semantic_kernel.functions.kernel_function_metadata import KernelFunctionMetadata logger: logging.Logger = logging.getLogger(__name__) @@ -96,3 +104,109 @@ def format_assistant_message(message: ChatMessageContent) -> list[Part]: The formatted assistant message as a list of parts. """ return [Part(text=message.content)] + + +def format_tool_message(message: ChatMessageContent) -> list[Part]: + """Format a tool message to the expected object for the client. + + Args: + message: The tool message. + + Returns: + The formatted tool message. + """ + if len(message.items) != 1: + logger.warning( + "Unsupported number of items in Tool message while formatting chat history for Google AI: " + f"{len(message.items)}" + ) + + if not isinstance(message.items[0], FunctionResultContent): + raise ValueError("No FunctionResultContent found in the message items") + + gemini_function_name = format_function_result_content_name_to_gemini_function_name(message.items[0]) + + return [ + Part( + function_response=FunctionResponse( + name=gemini_function_name, + response={ + "name": gemini_function_name, + "content": json.dumps(message.items[0].result), + }, + ) + ) + ] + + +_FUNCTION_CHOICE_TYPE_TO_GOOGLE_FUNCTION_CALLING_MODE = { + FunctionChoiceType.AUTO: "AUTO", + FunctionChoiceType.NONE: "NONE", + FunctionChoiceType.REQUIRED: "ANY", +} + +# The separator used in the fully qualified name of the function instead of the default "-" separator. +# This is required since Gemini doesn't work well with "-" in the function name. 
+# https://ai.google.dev/gemini-api/docs/function-calling#function_declarations +GEMINI_FUNCTION_NAME_SEPARATOR = "_" + + +def format_function_result_content_name_to_gemini_function_name(function_result_content: FunctionResultContent) -> str: + """Format the function result content name to the Gemini function name.""" + return ( + f"{function_result_content.plugin_name}{GEMINI_FUNCTION_NAME_SEPARATOR}{function_result_content.function_name}" + if function_result_content.plugin_name + else function_result_content.function_name + ) + + +def format_kernel_function_fully_qualified_name_to_gemini_function_name(metadata: KernelFunctionMetadata) -> str: + """Format the kernel function fully qualified name to the Gemini function name.""" + return ( + f"{metadata.plugin_name}{GEMINI_FUNCTION_NAME_SEPARATOR}{metadata.name}" + if metadata.plugin_name + else metadata.name + ) + + +def format_gemini_function_name_to_kernel_function_fully_qualified_name(gemini_function_name: str) -> str: + """Format the Gemini function name to the kernel function fully qualified name.""" + if GEMINI_FUNCTION_NAME_SEPARATOR in gemini_function_name: + plugin_name, function_name = gemini_function_name.split(GEMINI_FUNCTION_NAME_SEPARATOR, 1) + return f"{plugin_name}-{function_name}" + return gemini_function_name + + +def kernel_function_metadata_to_google_function_call_format(metadata: KernelFunctionMetadata) -> dict[str, Any]: + """Convert the kernel function metadata to function calling format.""" + return { + "name": format_kernel_function_fully_qualified_name_to_gemini_function_name(metadata), + "description": metadata.description or "", + "parameters": { + "type": "object", + "properties": {param.name: param.schema_data for param in metadata.parameters}, + "required": [p.name for p in metadata.parameters if p.is_required], + }, + } + + +def update_settings_from_function_choice_configuration( + function_choice_configuration: FunctionCallChoiceConfiguration, + settings: GoogleAIChatPromptExecutionSettings, + type: FunctionChoiceType, +) -> None: + """Update the settings from a FunctionChoiceConfiguration.""" + if function_choice_configuration.available_functions: + settings.tool_config = { + "function_calling_config": { + "mode": _FUNCTION_CHOICE_TYPE_TO_GOOGLE_FUNCTION_CALLING_MODE[type], + } + } + settings.tools = [ + { + "function_declarations": [ + kernel_function_metadata_to_google_function_call_format(f) + for f in function_choice_configuration.available_functions + ] + } + ] diff --git a/python/semantic_kernel/contents/function_result_content.py b/python/semantic_kernel/contents/function_result_content.py index 4da3162936ac..edb471d9620b 100644 --- a/python/semantic_kernel/contents/function_result_content.py +++ b/python/semantic_kernel/contents/function_result_content.py @@ -150,12 +150,10 @@ def from_function_call_content_and_result( metadata=metadata, ) - def to_chat_message_content(self, unwrap: bool = False) -> "ChatMessageContent": + def to_chat_message_content(self) -> "ChatMessageContent": """Convert the instance to a ChatMessageContent.""" from semantic_kernel.contents.chat_message_content import ChatMessageContent - if unwrap and isinstance(self.result, str): - return ChatMessageContent(role=AuthorRole.TOOL, content=self.result) return ChatMessageContent(role=AuthorRole.TOOL, items=[self]) def to_dict(self) -> dict[str, str]: diff --git a/python/tests/integration/completions/test_chat_completions.py b/python/tests/integration/completions/test_chat_completions.py index 9f16b22bd1a3..886fbbb74d8c 100644 
--- a/python/tests/integration/completions/test_chat_completions.py +++ b/python/tests/integration/completions/test_chat_completions.py @@ -495,6 +495,20 @@ def services() -> dict[str, tuple[ChatCompletionClientBase, type[PromptExecution ["house", "germany"], id="google_ai_image_input_file", ), + pytest.param( + "google_ai", + { + "function_choice_behavior": FunctionChoiceBehavior.Auto( + auto_invoke=True, filters={"excluded_plugins": ["chat"]} + ), + "max_tokens": 256, + }, + [ + ChatMessageContent(role=AuthorRole.USER, items=[TextContent(text="What is 3+345?")]), + ], + ["348"], + id="google_ai_tool_call_auto", + ), ], ) @@ -560,6 +574,7 @@ async def execute_invoke(kernel: Kernel, history: ChatHistory, output: str, stre response = invocation.value[0] print(response) if isinstance(response, ChatMessageContent): + assert response.items, "No items in response" for item in response.items: if isinstance(item, TextContent): assert item.text is not None From e06e20776bf074453ae38a52753a5c82e4a2e3ba Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Wed, 31 Jul 2024 15:44:14 -0700 Subject: [PATCH 02/13] Google AI function calling done --- .../services/google_ai_chat_completion.py | 88 ++++++++++++++++++- 1 file changed, 86 insertions(+), 2 deletions(-) diff --git a/python/semantic_kernel/connectors/ai/google/google_ai/services/google_ai_chat_completion.py b/python/semantic_kernel/connectors/ai/google/google_ai/services/google_ai_chat_completion.py index 0c072367b301..9d739d7375d2 100644 --- a/python/semantic_kernel/connectors/ai/google/google_ai/services/google_ai_chat_completion.py +++ b/python/semantic_kernel/connectors/ai/google/google_ai/services/google_ai_chat_completion.py @@ -5,6 +5,7 @@ import logging import sys from collections.abc import AsyncGenerator +from functools import reduce from typing import TYPE_CHECKING, Any import google.generativeai as genai @@ -28,7 +29,9 @@ ) from semantic_kernel.connectors.ai.google.shared_utils import filter_system_message from semantic_kernel.contents.function_call_content import FunctionCallContent +from semantic_kernel.contents.streaming_chat_message_content import ITEM_TYPES as STREAMING_ITEM_TYPES from semantic_kernel.contents.streaming_chat_message_content import StreamingChatMessageContent +from semantic_kernel.contents.streaming_text_content import StreamingTextContent from semantic_kernel.contents.text_content import TextContent from semantic_kernel.contents.utils.author_role import AuthorRole from semantic_kernel.contents.utils.finish_reason import FinishReason @@ -220,11 +223,68 @@ async def get_streaming_chat_message_contents( settings = self.get_prompt_execution_settings_from_settings(settings) assert isinstance(settings, GoogleAIChatPromptExecutionSettings) # nosec - async_generator = self._send_chat_streaming_request(chat_history, settings) + if ( + settings.function_choice_behavior is None + or not settings.function_choice_behavior.auto_invoke_kernel_functions + ): + # No auto invoke is required. + async_generator = self._send_chat_streaming_request(chat_history, settings) + else: + # Auto invoke is required. 
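+            # The helper below streams each chunk back to the caller while also
+            # accumulating the chunks per request; when a response contains
+            # function calls, they are invoked and the request is retried, up
+            # to maximum_auto_invoke_attempts times.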
+ async_generator = self._get_streaming_chat_message_contents_auto_invoke(chat_history, settings, **kwargs) async for messages in async_generator: yield messages + async def _get_streaming_chat_message_contents_auto_invoke( + self, + chat_history: ChatHistory, + settings: GoogleAIChatPromptExecutionSettings, + **kwargs: Any, + ) -> AsyncGenerator[list[StreamingChatMessageContent], Any]: + """Get streaming chat message contents from the Google AI service with auto invoking functions.""" + kernel = kwargs.get("kernel") + if not isinstance(kernel, Kernel): + raise ServiceInvalidExecutionSettingsError("Kernel is required for auto invoking functions.") + if not settings.function_choice_behavior: + raise ServiceInvalidExecutionSettingsError( + "Function choice behavior is required for auto invoking functions." + ) + + self._configure_function_choice_behavior(settings, kernel) + + for request_index in range(settings.function_choice_behavior.maximum_auto_invoke_attempts): + all_messages: list[StreamingChatMessageContent] = [] + function_call_returned = False + async for messages in self._send_chat_streaming_request(chat_history, settings): + for message in messages: + if message: + all_messages.append(message) + if any(isinstance(item, FunctionCallContent) for item in message.items): + function_call_returned = True + yield messages + + if not function_call_returned: + # Response doesn't contain any function calls. No need to proceed to the next request. + return + + full_completion: StreamingChatMessageContent = reduce(lambda x, y: x + y, all_messages) + function_calls = [item for item in full_completion.items if isinstance(item, FunctionCallContent)] + chat_history.add_message(message=full_completion) + + results = await self._invoke_function_calls( + function_calls=function_calls, + chat_history=chat_history, + kernel=kernel, + arguments=kwargs.get("argument", None), + function_call_count=len(function_calls), + request_index=request_index, + function_behavior=settings.function_choice_behavior, + ) + + if any(result.terminate for result in results if result is not None): + return + async def _send_chat_streaming_request( self, chat_history: ChatHistory, @@ -240,6 +300,8 @@ async def _send_chat_streaming_request( response: AsyncGenerateContentResponse = await model.generate_content_async( contents=self._prepare_chat_history_for_request(chat_history), generation_config=GenerationConfig(**settings.prepare_settings_dict()), + tools=settings.tools, + tool_config=settings.tool_config, stream=True, ) @@ -265,11 +327,33 @@ def _create_streaming_chat_message_content( response_metadata = self._get_metadata_from_response(chunk) response_metadata.update(self._get_metadata_from_candidate(candidate)) + items: list[STREAMING_ITEM_TYPES] = [] + for idx, part in enumerate(candidate.content.parts): + if part.text: + items.append( + StreamingTextContent( + choice_index=candidate.index, + text=part.text, + inner_content=chunk, + metadata=response_metadata, + ) + ) + elif part.function_call: + items.append( + FunctionCallContent( + id=f"{part.function_call.name}_{idx!s}", + name=format_gemini_function_name_to_kernel_function_fully_qualified_name( + part.function_call.name + ), + arguments={k: v for k, v in part.function_call.args.items()}, + ) + ) + return StreamingChatMessageContent( ai_model_id=self.ai_model_id, role=AuthorRole.ASSISTANT, choice_index=candidate.index, - content=candidate.content.parts[0].text, + items=items, inner_content=chunk, finish_reason=finish_reason, metadata=response_metadata, From 
71f8c27bcf7610f10d16e3d3d661a9628df85814 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Wed, 31 Jul 2024 22:50:38 -0700 Subject: [PATCH 03/13] Vertex AI function calling done; next: more tests --- .../services/google_ai_chat_completion.py | 57 ++---- .../ai/google/google_ai/services/utils.py | 49 +----- .../connectors/ai/google/shared_utils.py | 100 ++++++++++- .../ai/google/vertex_ai/services/utils.py | 129 +++++++++++++- .../services/vertex_ai_chat_completion.py | 165 +++++++++++++++++- .../vertex_ai_prompt_execution_settings.py | 9 +- .../completions/test_chat_completions.py | 14 ++ 7 files changed, 419 insertions(+), 104 deletions(-) diff --git a/python/semantic_kernel/connectors/ai/google/google_ai/services/google_ai_chat_completion.py b/python/semantic_kernel/connectors/ai/google/google_ai/services/google_ai_chat_completion.py index 9d739d7375d2..2e8463b9f6f6 100644 --- a/python/semantic_kernel/connectors/ai/google/google_ai/services/google_ai_chat_completion.py +++ b/python/semantic_kernel/connectors/ai/google/google_ai/services/google_ai_chat_completion.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft. All rights reserved. -import asyncio import logging import sys from collections.abc import AsyncGenerator @@ -14,7 +13,6 @@ from google.generativeai.types import AsyncGenerateContentResponse, GenerateContentResponse, GenerationConfig from pydantic import ValidationError -from semantic_kernel.connectors.ai.function_choice_behavior import FunctionChoiceBehavior from semantic_kernel.connectors.ai.google.google_ai.google_ai_prompt_execution_settings import ( GoogleAIChatPromptExecutionSettings, ) @@ -22,12 +20,16 @@ from semantic_kernel.connectors.ai.google.google_ai.services.utils import ( finish_reason_from_google_ai_to_semantic_kernel, format_assistant_message, - format_gemini_function_name_to_kernel_function_fully_qualified_name, format_tool_message, format_user_message, update_settings_from_function_choice_configuration, ) -from semantic_kernel.connectors.ai.google.shared_utils import filter_system_message +from semantic_kernel.connectors.ai.google.shared_utils import ( + configure_function_choice_behavior, + filter_system_message, + format_gemini_function_name_to_kernel_function_fully_qualified_name, + invoke_function_calls, +) from semantic_kernel.contents.function_call_content import FunctionCallContent from semantic_kernel.contents.streaming_chat_message_content import ITEM_TYPES as STREAMING_ITEM_TYPES from semantic_kernel.contents.streaming_chat_message_content import StreamingChatMessageContent @@ -35,7 +37,6 @@ from semantic_kernel.contents.text_content import TextContent from semantic_kernel.contents.utils.author_role import AuthorRole from semantic_kernel.contents.utils.finish_reason import FinishReason -from semantic_kernel.functions.kernel_arguments import KernelArguments from semantic_kernel.kernel import Kernel if sys.version_info >= (3, 12): @@ -125,7 +126,7 @@ async def get_chat_message_contents( if not isinstance(kernel, Kernel): raise ServiceInvalidExecutionSettingsError("Kernel is required for auto invoking functions.") - self._configure_function_choice_behavior(settings, kernel) + configure_function_choice_behavior(settings, kernel, update_settings_from_function_choice_configuration) for request_index in range(settings.function_choice_behavior.maximum_auto_invoke_attempts): completions = await self._send_chat_request(chat_history, settings) @@ -134,7 +135,7 @@ async def get_chat_message_contents( if (fc_count := len(function_calls)) == 0: return completions - results 
= await self._invoke_function_calls( + results = await invoke_function_calls( function_calls=function_calls, chat_history=chat_history, kernel=kernel, @@ -251,7 +252,7 @@ async def _get_streaming_chat_message_contents_auto_invoke( "Function choice behavior is required for auto invoking functions." ) - self._configure_function_choice_behavior(settings, kernel) + configure_function_choice_behavior(settings, kernel, update_settings_from_function_choice_configuration) for request_index in range(settings.function_choice_behavior.maximum_auto_invoke_attempts): all_messages: list[StreamingChatMessageContent] = [] @@ -272,7 +273,7 @@ async def _get_streaming_chat_message_contents_auto_invoke( function_calls = [item for item in full_completion.items if isinstance(item, FunctionCallContent)] chat_history.add_message(message=full_completion) - results = await self._invoke_function_calls( + results = await invoke_function_calls( function_calls=function_calls, chat_history=chat_history, kernel=kernel, @@ -418,44 +419,6 @@ def _get_metadata_from_candidate(self, candidate: Candidate) -> dict[str, Any]: "token_count": candidate.token_count, } - def _configure_function_choice_behavior(self, settings: GoogleAIChatPromptExecutionSettings, kernel: Kernel): - """Configure the function choice behavior to include the kernel functions.""" - if not settings.function_choice_behavior: - raise ServiceInvalidExecutionSettingsError("Function choice behavior is required for tool calls.") - - settings.function_choice_behavior.configure( - kernel=kernel, - update_settings_callback=update_settings_from_function_choice_configuration, - settings=settings, - ) - - async def _invoke_function_calls( - self, - function_calls: list[FunctionCallContent], - chat_history: ChatHistory, - kernel: Kernel, - arguments: KernelArguments | None, - function_call_count: int, - request_index: int, - function_behavior: FunctionChoiceBehavior, - ): - """Invoke function calls.""" - logger.info(f"processing {function_call_count} tool calls in parallel.") - - return await asyncio.gather( - *[ - kernel.invoke_function_call( - function_call=function_call, - chat_history=chat_history, - arguments=arguments, - function_call_count=function_call_count, - request_index=request_index, - function_behavior=function_behavior, - ) - for function_call in function_calls - ], - ) - @override def get_prompt_execution_settings_class( self, diff --git a/python/semantic_kernel/connectors/ai/google/google_ai/services/utils.py b/python/semantic_kernel/connectors/ai/google/google_ai/services/utils.py index fcb34c2ed6c0..2570c82e0a3f 100644 --- a/python/semantic_kernel/connectors/ai/google/google_ai/services/utils.py +++ b/python/semantic_kernel/connectors/ai/google/google_ai/services/utils.py @@ -10,6 +10,11 @@ from semantic_kernel.connectors.ai.google.google_ai.google_ai_prompt_execution_settings import ( GoogleAIChatPromptExecutionSettings, ) +from semantic_kernel.connectors.ai.google.shared_utils import ( + FUNCTION_CHOICE_TYPE_TO_GOOGLE_FUNCTION_CALLING_MODE, + format_function_result_content_name_to_gemini_function_name, + format_kernel_function_fully_qualified_name_to_gemini_function_name, +) from semantic_kernel.contents.chat_message_content import ChatMessageContent from semantic_kernel.contents.function_result_content import FunctionResultContent from semantic_kernel.contents.image_content import ImageContent @@ -119,45 +124,7 @@ def format_tool_message(message: ChatMessageContent) -> list[Part]: ] -_FUNCTION_CHOICE_TYPE_TO_GOOGLE_FUNCTION_CALLING_MODE = { - 
FunctionChoiceType.AUTO: "AUTO", - FunctionChoiceType.NONE: "NONE", - FunctionChoiceType.REQUIRED: "ANY", -} - -# The separator used in the fully qualified name of the function instead of the default "-" separator. -# This is required since Gemini doesn't work well with "-" in the function name. -# https://ai.google.dev/gemini-api/docs/function-calling#function_declarations -GEMINI_FUNCTION_NAME_SEPARATOR = "_" - - -def format_function_result_content_name_to_gemini_function_name(function_result_content: FunctionResultContent) -> str: - """Format the function result content name to the Gemini function name.""" - return ( - f"{function_result_content.plugin_name}{GEMINI_FUNCTION_NAME_SEPARATOR}{function_result_content.function_name}" - if function_result_content.plugin_name - else function_result_content.function_name - ) - - -def format_kernel_function_fully_qualified_name_to_gemini_function_name(metadata: KernelFunctionMetadata) -> str: - """Format the kernel function fully qualified name to the Gemini function name.""" - return ( - f"{metadata.plugin_name}{GEMINI_FUNCTION_NAME_SEPARATOR}{metadata.name}" - if metadata.plugin_name - else metadata.name - ) - - -def format_gemini_function_name_to_kernel_function_fully_qualified_name(gemini_function_name: str) -> str: - """Format the Gemini function name to the kernel function fully qualified name.""" - if GEMINI_FUNCTION_NAME_SEPARATOR in gemini_function_name: - plugin_name, function_name = gemini_function_name.split(GEMINI_FUNCTION_NAME_SEPARATOR, 1) - return f"{plugin_name}-{function_name}" - return gemini_function_name - - -def kernel_function_metadata_to_google_function_call_format(metadata: KernelFunctionMetadata) -> dict[str, Any]: +def kernel_function_metadata_to_google_ai_function_call_format(metadata: KernelFunctionMetadata) -> dict[str, Any]: """Convert the kernel function metadata to function calling format.""" return { "name": format_kernel_function_fully_qualified_name_to_gemini_function_name(metadata), @@ -179,13 +146,13 @@ def update_settings_from_function_choice_configuration( if function_choice_configuration.available_functions: settings.tool_config = { "function_calling_config": { - "mode": _FUNCTION_CHOICE_TYPE_TO_GOOGLE_FUNCTION_CALLING_MODE[type], + "mode": FUNCTION_CHOICE_TYPE_TO_GOOGLE_FUNCTION_CALLING_MODE[type], } } settings.tools = [ { "function_declarations": [ - kernel_function_metadata_to_google_function_call_format(f) + kernel_function_metadata_to_google_ai_function_call_format(f) for f in function_choice_configuration.available_functions ] } diff --git a/python/semantic_kernel/connectors/ai/google/shared_utils.py b/python/semantic_kernel/connectors/ai/google/shared_utils.py index e898c9a8f3e1..832a86970391 100644 --- a/python/semantic_kernel/connectors/ai/google/shared_utils.py +++ b/python/semantic_kernel/connectors/ai/google/shared_utils.py @@ -1,8 +1,29 @@ # Copyright (c) Microsoft. All rights reserved. 
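+"""Utility functions shared by the Google AI and Vertex AI connectors."""
+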
+import asyncio +import logging +from collections.abc import Callable + +from semantic_kernel.connectors.ai.function_choice_behavior import FunctionChoiceBehavior, FunctionChoiceType +from semantic_kernel.connectors.ai.google.google_ai.google_ai_prompt_execution_settings import ( + GoogleAIChatPromptExecutionSettings, +) +from semantic_kernel.connectors.ai.google.vertex_ai.vertex_ai_prompt_execution_settings import ( + VertexAIChatPromptExecutionSettings, +) from semantic_kernel.contents.chat_history import ChatHistory +from semantic_kernel.contents.function_call_content import FunctionCallContent +from semantic_kernel.contents.function_result_content import FunctionResultContent from semantic_kernel.contents.utils.author_role import AuthorRole -from semantic_kernel.exceptions.service_exceptions import ServiceInvalidRequestError +from semantic_kernel.exceptions.service_exceptions import ( + ServiceInvalidExecutionSettingsError, + ServiceInvalidRequestError, +) +from semantic_kernel.functions.kernel_arguments import KernelArguments +from semantic_kernel.functions.kernel_function_metadata import KernelFunctionMetadata +from semantic_kernel.kernel import Kernel + +logger: logging.Logger = logging.getLogger(__name__) def filter_system_message(chat_history: ChatHistory) -> str | None: @@ -21,3 +42,80 @@ def filter_system_message(chat_history: ChatHistory) -> str | None: return message.content return None + + +async def invoke_function_calls( + function_calls: list[FunctionCallContent], + chat_history: ChatHistory, + kernel: Kernel, + arguments: KernelArguments | None, + function_call_count: int, + request_index: int, + function_behavior: FunctionChoiceBehavior, +): + """Invoke function calls.""" + logger.info(f"processing {function_call_count} tool calls in parallel.") + + return await asyncio.gather( + *[ + kernel.invoke_function_call( + function_call=function_call, + chat_history=chat_history, + arguments=arguments, + function_call_count=function_call_count, + request_index=request_index, + function_behavior=function_behavior, + ) + for function_call in function_calls + ], + ) + + +FUNCTION_CHOICE_TYPE_TO_GOOGLE_FUNCTION_CALLING_MODE = { + FunctionChoiceType.AUTO: "AUTO", + FunctionChoiceType.NONE: "NONE", + FunctionChoiceType.REQUIRED: "ANY", +} + +# The separator used in the fully qualified name of the function instead of the default "-" separator. +# This is required since Gemini doesn't work well with "-" in the function name. 
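+# For example, the kernel's fully qualified name "math-Add" (plugin "math",
+# function "Add") becomes the Gemini function name "math_Add" and is mapped back
+# when the model returns a function call.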
+# https://ai.google.dev/gemini-api/docs/function-calling#function_declarations +GEMINI_FUNCTION_NAME_SEPARATOR = "_" + + +def format_function_result_content_name_to_gemini_function_name(function_result_content: FunctionResultContent) -> str: + """Format the function result content name to the Gemini function name.""" + return ( + f"{function_result_content.plugin_name}{GEMINI_FUNCTION_NAME_SEPARATOR}{function_result_content.function_name}" + if function_result_content.plugin_name + else function_result_content.function_name + ) + + +def format_kernel_function_fully_qualified_name_to_gemini_function_name(metadata: KernelFunctionMetadata) -> str: + """Format the kernel function fully qualified name to the Gemini function name.""" + return ( + f"{metadata.plugin_name}{GEMINI_FUNCTION_NAME_SEPARATOR}{metadata.name}" + if metadata.plugin_name + else metadata.name + ) + + +def format_gemini_function_name_to_kernel_function_fully_qualified_name(gemini_function_name: str) -> str: + """Format the Gemini function name to the kernel function fully qualified name.""" + if GEMINI_FUNCTION_NAME_SEPARATOR in gemini_function_name: + plugin_name, function_name = gemini_function_name.split(GEMINI_FUNCTION_NAME_SEPARATOR, 1) + return f"{plugin_name}-{function_name}" + return gemini_function_name + + +def configure_function_choice_behavior( + settings: GoogleAIChatPromptExecutionSettings | VertexAIChatPromptExecutionSettings, + kernel: Kernel, + callback: Callable[..., None], +): + """Configure the function choice behavior to include the kernel functions.""" + if not settings.function_choice_behavior: + raise ServiceInvalidExecutionSettingsError("Function choice behavior is required for tool calls.") + + settings.function_choice_behavior.configure(kernel=kernel, update_settings_callback=callback, settings=settings) diff --git a/python/semantic_kernel/connectors/ai/google/vertex_ai/services/utils.py b/python/semantic_kernel/connectors/ai/google/vertex_ai/services/utils.py index 5ff800986115..509fd6fbba28 100644 --- a/python/semantic_kernel/connectors/ai/google/vertex_ai/services/utils.py +++ b/python/semantic_kernel/connectors/ai/google/vertex_ai/services/utils.py @@ -1,12 +1,31 @@ # Copyright (c) Microsoft. All rights reserved. 
-from google.cloud.aiplatform_v1beta1.types.content import Blob, Candidate, Part +import logging +from typing import Any +from google.cloud.aiplatform_v1beta1.types.content import Blob, Candidate, Part +from google.cloud.aiplatform_v1beta1.types.tool import FunctionCall, FunctionResponse +from vertexai.generative_models import FunctionDeclaration, Tool, ToolConfig + +from semantic_kernel.connectors.ai.function_choice_behavior import FunctionCallChoiceConfiguration, FunctionChoiceType +from semantic_kernel.connectors.ai.google.shared_utils import ( + FUNCTION_CHOICE_TYPE_TO_GOOGLE_FUNCTION_CALLING_MODE, + format_function_result_content_name_to_gemini_function_name, + format_kernel_function_fully_qualified_name_to_gemini_function_name, +) +from semantic_kernel.connectors.ai.google.vertex_ai.vertex_ai_prompt_execution_settings import ( + VertexAIChatPromptExecutionSettings, +) from semantic_kernel.contents.chat_message_content import ChatMessageContent +from semantic_kernel.contents.function_call_content import FunctionCallContent +from semantic_kernel.contents.function_result_content import FunctionResultContent from semantic_kernel.contents.image_content import ImageContent from semantic_kernel.contents.text_content import TextContent from semantic_kernel.contents.utils.finish_reason import FinishReason as SemanticKernelFinishReason from semantic_kernel.exceptions.service_exceptions import ServiceInvalidRequestError +from semantic_kernel.functions.kernel_function_metadata import KernelFunctionMetadata + +logger: logging.Logger = logging.getLogger(__name__) def finish_reason_from_vertex_ai_to_semantic_kernel( @@ -51,11 +70,11 @@ def format_user_message(message: ChatMessageContent) -> list[Part]: # The Google AI API doesn't support images from arbitrary URIs: # https://github.com/google-gemini/generative-ai-python/issues/357 raise ServiceInvalidRequestError( - "ImageContent without data_uri in User message while formatting chat history for Google AI" + "ImageContent without data_uri in User message while formatting chat history for Vertex AI" ) else: raise ServiceInvalidRequestError( - "Unsupported item type in User message while formatting chat history for Google AI" + "Unsupported item type in User message while formatting chat history for Vertex AI" f" Inference: {type(item)}" ) @@ -71,4 +90,106 @@ def format_assistant_message(message: ChatMessageContent) -> list[Part]: Returns: The formatted assistant message as a list of parts. 
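+
+    Raises:
+        ServiceInvalidRequestError: If the message contains an unsupported item
+            type, more than one text item, or more than one function call item.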
""" - return [Part(text=message.content)] + text_items: list[TextContent] = [] + function_call_items: list[FunctionCallContent] = [] + for item in message.items: + if isinstance(item, TextContent): + text_items.append(item) + elif isinstance(item, FunctionCallContent): + function_call_items.append(item) + else: + raise ServiceInvalidRequestError( + "Unsupported item type in Assistant message while formatting chat history for Vertex AI" + f" Inference: {type(item)}" + ) + + if len(text_items) > 1: + raise ServiceInvalidRequestError( + "Unsupported number of text items in Assistant message while formatting chat history for Vertex AI" + f" Inference: {len(text_items)}" + ) + + if len(function_call_items) > 1: + raise ServiceInvalidRequestError( + "Unsupported number of function call items in Assistant message while formatting chat history for Vertex AI" + f" Inference: {len(function_call_items)}" + ) + + part = Part() + if text_items: + part.text = text_items[0].text + if function_call_items: + part.function_call = FunctionCall( + name=function_call_items[0].name, + args=function_call_items[0].arguments, + ) + + return [part] + + +def format_tool_message(message: ChatMessageContent) -> list[Part]: + """Format a tool message to the expected object for the client. + + Args: + message: The tool message. + + Returns: + The formatted tool message. + """ + if len(message.items) != 1: + logger.warning( + "Unsupported number of items in Tool message while formatting chat history for Vertex AI: " + f"{len(message.items)}" + ) + + if not isinstance(message.items[0], FunctionResultContent): + raise ValueError("No FunctionResultContent found in the message items") + + gemini_function_name = format_function_result_content_name_to_gemini_function_name(message.items[0]) + + return [ + Part( + function_response=FunctionResponse( + name=gemini_function_name, + response={ + "name": gemini_function_name, + "content": message.items[0].result, + }, + ), + ) + ] + + +def kernel_function_metadata_to_vertex_ai_function_call_format(metadata: KernelFunctionMetadata) -> dict[str, Any]: + """Convert the kernel function metadata to function calling format.""" + return FunctionDeclaration( + name=format_kernel_function_fully_qualified_name_to_gemini_function_name(metadata), + description=metadata.description or "", + parameters={ + "type": "object", + "properties": {param.name: param.schema_data for param in metadata.parameters}, + "required": [p.name for p in metadata.parameters if p.is_required], + }, + ) + + +def update_settings_from_function_choice_configuration( + function_choice_configuration: FunctionCallChoiceConfiguration, + settings: VertexAIChatPromptExecutionSettings, + type: FunctionChoiceType, +) -> None: + """Update the settings from a FunctionChoiceConfiguration.""" + if function_choice_configuration.available_functions: + settings.tool_config = ToolConfig( + function_calling_config=ToolConfig.FunctionCallingConfig( + mode=FUNCTION_CHOICE_TYPE_TO_GOOGLE_FUNCTION_CALLING_MODE[type], + ), + ) + settings.tools = [ + Tool( + function_declarations=[ + kernel_function_metadata_to_vertex_ai_function_call_format(f) + for f in function_choice_configuration.available_functions + ] + ) + ] diff --git a/python/semantic_kernel/connectors/ai/google/vertex_ai/services/vertex_ai_chat_completion.py b/python/semantic_kernel/connectors/ai/google/vertex_ai/services/vertex_ai_chat_completion.py index 00dd251a0a9d..16c4cca58720 100644 --- 
a/python/semantic_kernel/connectors/ai/google/vertex_ai/services/vertex_ai_chat_completion.py +++ b/python/semantic_kernel/connectors/ai/google/vertex_ai/services/vertex_ai_chat_completion.py @@ -2,6 +2,7 @@ import sys from collections.abc import AsyncGenerator, AsyncIterable +from functools import reduce from typing import Any import vertexai @@ -9,11 +10,18 @@ from pydantic import ValidationError from vertexai.generative_models import Candidate, GenerationResponse, GenerativeModel -from semantic_kernel.connectors.ai.google.shared_utils import filter_system_message +from semantic_kernel.connectors.ai.google.shared_utils import ( + configure_function_choice_behavior, + filter_system_message, + format_gemini_function_name_to_kernel_function_fully_qualified_name, + invoke_function_calls, +) from semantic_kernel.connectors.ai.google.vertex_ai.services.utils import ( finish_reason_from_vertex_ai_to_semantic_kernel, format_assistant_message, + format_tool_message, format_user_message, + update_settings_from_function_choice_configuration, ) from semantic_kernel.connectors.ai.google.vertex_ai.services.vertex_ai_base import VertexAIBase from semantic_kernel.connectors.ai.google.vertex_ai.vertex_ai_prompt_execution_settings import ( @@ -22,11 +30,19 @@ from semantic_kernel.connectors.ai.google.vertex_ai.vertex_ai_settings import VertexAISettings from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings from semantic_kernel.contents.chat_history import ChatHistory -from semantic_kernel.contents.chat_message_content import ChatMessageContent +from semantic_kernel.contents.chat_message_content import ITEM_TYPES, ChatMessageContent +from semantic_kernel.contents.function_call_content import FunctionCallContent +from semantic_kernel.contents.streaming_chat_message_content import ITEM_TYPES as STREAMING_ITEM_TYPES from semantic_kernel.contents.streaming_chat_message_content import StreamingChatMessageContent +from semantic_kernel.contents.streaming_text_content import StreamingTextContent +from semantic_kernel.contents.text_content import TextContent from semantic_kernel.contents.utils.author_role import AuthorRole from semantic_kernel.contents.utils.finish_reason import FinishReason -from semantic_kernel.exceptions.service_exceptions import ServiceInitializationError +from semantic_kernel.exceptions.service_exceptions import ( + ServiceInitializationError, + ServiceInvalidExecutionSettingsError, +) +from semantic_kernel.kernel import Kernel if sys.version_info >= (3, 12): from typing import override # pragma: no cover @@ -90,7 +106,40 @@ async def get_chat_message_contents( settings = self.get_prompt_execution_settings_from_settings(settings) assert isinstance(settings, VertexAIChatPromptExecutionSettings) # nosec - return await self._send_chat_request(chat_history, settings) + if ( + settings.function_choice_behavior is None + or not settings.function_choice_behavior.auto_invoke_kernel_functions + ): + return await self._send_chat_request(chat_history, settings) + + kernel = kwargs.get("kernel") + if not isinstance(kernel, Kernel): + raise ServiceInvalidExecutionSettingsError("Kernel is required for auto invoking functions.") + + configure_function_choice_behavior(settings, kernel, update_settings_from_function_choice_configuration) + + for request_index in range(settings.function_choice_behavior.maximum_auto_invoke_attempts): + completions = await self._send_chat_request(chat_history, settings) + chat_history.add_message(message=completions[0]) + function_calls = 
[item for item in chat_history.messages[-1].items if isinstance(item, FunctionCallContent)] + if (fc_count := len(function_calls)) == 0: + return completions + + results = await invoke_function_calls( + function_calls=function_calls, + chat_history=chat_history, + kernel=kernel, + arguments=kwargs.get("argument", None), + function_call_count=fc_count, + request_index=request_index, + function_behavior=settings.function_choice_behavior, + ) + + if any(result.terminate for result in results if result is not None): + return completions + else: + # do a final call without auto function calling + return await self._send_chat_request(chat_history, settings) async def _send_chat_request( self, chat_history: ChatHistory, settings: VertexAIChatPromptExecutionSettings @@ -105,6 +154,8 @@ async def _send_chat_request( response: GenerationResponse = await model.generate_content_async( contents=self._prepare_chat_history_for_request(chat_history), generation_config=settings.prepare_settings_dict(), + tools=settings.tools, + tool_config=settings.tool_config, ) return [self._create_chat_message_content(response, candidate) for candidate in response.candidates] @@ -124,10 +175,26 @@ def _create_chat_message_content(self, response: GenerationResponse, candidate: response_metadata = self._get_metadata_from_response(response) response_metadata.update(self._get_metadata_from_candidate(candidate)) + items: list[ITEM_TYPES] = [] + for idx, part in enumerate(candidate.content.parts): + part_dict = part.to_dict() + if "text" in part_dict: + items.append(TextContent(text=part.text, inner_content=response, metadata=response_metadata)) + elif "function_call" in part_dict: + items.append( + FunctionCallContent( + id=f"{part.function_call.name}_{idx!s}", + name=format_gemini_function_name_to_kernel_function_fully_qualified_name( + part.function_call.name + ), + arguments={k: v for k, v in part.function_call.args.items()}, + ) + ) + return ChatMessageContent( ai_model_id=self.ai_model_id, role=AuthorRole.ASSISTANT, - content=candidate.content.parts[0].text, + items=items, inner_content=response, finish_reason=finish_reason, metadata=response_metadata, @@ -146,11 +213,68 @@ async def get_streaming_chat_message_contents( settings = self.get_prompt_execution_settings_from_settings(settings) assert isinstance(settings, VertexAIChatPromptExecutionSettings) # nosec - async_generator = self._send_chat_streaming_request(chat_history, settings) + if ( + settings.function_choice_behavior is None + or not settings.function_choice_behavior.auto_invoke_kernel_functions + ): + # No auto invoke is required. + async_generator = self._send_chat_streaming_request(chat_history, settings) + else: + # Auto invoke is required. + async_generator = self._get_streaming_chat_message_contents_auto_invoke(chat_history, settings, **kwargs) async for messages in async_generator: yield messages + async def _get_streaming_chat_message_contents_auto_invoke( + self, + chat_history: ChatHistory, + settings: VertexAIChatPromptExecutionSettings, + **kwargs: Any, + ) -> AsyncGenerator[list[StreamingChatMessageContent], Any]: + """Get streaming chat message contents from the Google AI service with auto invoking functions.""" + kernel = kwargs.get("kernel") + if not isinstance(kernel, Kernel): + raise ServiceInvalidExecutionSettingsError("Kernel is required for auto invoking functions.") + if not settings.function_choice_behavior: + raise ServiceInvalidExecutionSettingsError( + "Function choice behavior is required for auto invoking functions." 
+ ) + + configure_function_choice_behavior(settings, kernel, update_settings_from_function_choice_configuration) + + for request_index in range(settings.function_choice_behavior.maximum_auto_invoke_attempts): + all_messages: list[StreamingChatMessageContent] = [] + function_call_returned = False + async for messages in self._send_chat_streaming_request(chat_history, settings): + for message in messages: + if message: + all_messages.append(message) + if any(isinstance(item, FunctionCallContent) for item in message.items): + function_call_returned = True + yield messages + + if not function_call_returned: + # Response doesn't contain any function calls. No need to proceed to the next request. + return + + full_completion: StreamingChatMessageContent = reduce(lambda x, y: x + y, all_messages) + function_calls = [item for item in full_completion.items if isinstance(item, FunctionCallContent)] + chat_history.add_message(message=full_completion) + + results = await invoke_function_calls( + function_calls=function_calls, + chat_history=chat_history, + kernel=kernel, + arguments=kwargs.get("argument", None), + function_call_count=len(function_calls), + request_index=request_index, + function_behavior=settings.function_choice_behavior, + ) + + if any(result.terminate for result in results if result is not None): + return + async def _send_chat_streaming_request( self, chat_history: ChatHistory, @@ -166,6 +290,8 @@ async def _send_chat_streaming_request( response: AsyncIterable[GenerationResponse] = await model.generate_content_async( contents=self._prepare_chat_history_for_request(chat_history), generation_config=settings.prepare_settings_dict(), + tools=settings.tools, + tool_config=settings.tool_config, stream=True, ) @@ -191,11 +317,34 @@ def _create_streaming_chat_message_content( response_metadata = self._get_metadata_from_response(chunk) response_metadata.update(self._get_metadata_from_candidate(candidate)) + items: list[STREAMING_ITEM_TYPES] = [] + for idx, part in enumerate(candidate.content.parts): + part_dict = part.to_dict() + if "text" in part_dict: + items.append( + StreamingTextContent( + choice_index=candidate.index, + text=part.text, + inner_content=chunk, + metadata=response_metadata, + ) + ) + elif "function_call" in part_dict: + items.append( + FunctionCallContent( + id=f"{part.function_call.name}_{idx!s}", + name=format_gemini_function_name_to_kernel_function_fully_qualified_name( + part.function_call.name + ), + arguments={k: v for k, v in part.function_call.args.items()}, + ) + ) + return StreamingChatMessageContent( ai_model_id=self.ai_model_id, role=AuthorRole.ASSISTANT, choice_index=candidate.index, - content=candidate.content.parts[0].text, + items=items, inner_content=chunk, finish_reason=finish_reason, metadata=response_metadata, @@ -221,6 +370,8 @@ def _prepare_chat_history_for_request( chat_request_messages.append(Content(role="user", parts=format_user_message(message))) elif message.role == AuthorRole.ASSISTANT: chat_request_messages.append(Content(role="model", parts=format_assistant_message(message))) + elif message.role == AuthorRole.TOOL: + chat_request_messages.append(Content(role="function", parts=format_tool_message(message))) else: raise ValueError(f"Unsupported role: {message.role}") diff --git a/python/semantic_kernel/connectors/ai/google/vertex_ai/vertex_ai_prompt_execution_settings.py b/python/semantic_kernel/connectors/ai/google/vertex_ai/vertex_ai_prompt_execution_settings.py index fb2501079666..99389e95eb9f 100644 --- 
a/python/semantic_kernel/connectors/ai/google/vertex_ai/vertex_ai_prompt_execution_settings.py +++ b/python/semantic_kernel/connectors/ai/google/vertex_ai/vertex_ai_prompt_execution_settings.py @@ -4,6 +4,7 @@ from typing import Any, Literal from pydantic import Field +from vertexai.generative_models import Tool, ToolConfig if sys.version_info >= (3, 12): from typing import override # pragma: no cover @@ -35,19 +36,19 @@ class VertexAITextPromptExecutionSettings(VertexAIPromptExecutionSettings): class VertexAIChatPromptExecutionSettings(VertexAIPromptExecutionSettings): """Vertex AI Chat Prompt Execution Settings.""" - tools: list[dict[str, Any]] | None = Field(None, max_length=64) - tool_choice: str | None = None + tools: list[Tool] | None = Field(None, max_length=64) + tool_config: ToolConfig | None = None @override def prepare_settings_dict(self, **kwargs) -> dict[str, Any]: """Prepare the settings as a dictionary for sending to the AI service. - This method removes the tools and tool_choice keys from the settings dictionary, as + This method removes the tools and tool_config keys from the settings dictionary, as the Vertex AI service mandates these two settings to be sent as separate parameters. """ settings_dict = super().prepare_settings_dict(**kwargs) settings_dict.pop("tools", None) - settings_dict.pop("tool_choice", None) + settings_dict.pop("tool_config", None) return settings_dict diff --git a/python/tests/integration/completions/test_chat_completions.py b/python/tests/integration/completions/test_chat_completions.py index 3386b2f76919..289b90fef745 100644 --- a/python/tests/integration/completions/test_chat_completions.py +++ b/python/tests/integration/completions/test_chat_completions.py @@ -548,6 +548,20 @@ def services() -> dict[str, tuple[ChatCompletionClientBase, type[PromptExecution ["house", "germany"], id="vertex_ai_image_input_file", ), + pytest.param( + "vertex_ai", + { + "function_choice_behavior": FunctionChoiceBehavior.Auto( + auto_invoke=True, filters={"excluded_plugins": ["chat"]} + ), + "max_tokens": 256, + }, + [ + ChatMessageContent(role=AuthorRole.USER, items=[TextContent(text="What is 3+345?")]), + ], + ["348"], + id="vertex_ai_tool_call_auto", + ), ], ) From 958e2a16b1da8d92f5eaaf91abc1200f17e93844 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Thu, 1 Aug 2024 08:41:06 -0700 Subject: [PATCH 04/13] Integration tests done --- .../ai/google/vertex_ai/services/utils.py | 6 +- .../completions/test_chat_completions.py | 78 +++++++++++++++++++ 2 files changed, 83 insertions(+), 1 deletion(-) diff --git a/python/semantic_kernel/connectors/ai/google/vertex_ai/services/utils.py b/python/semantic_kernel/connectors/ai/google/vertex_ai/services/utils.py index 509fd6fbba28..6fb1ec652a01 100644 --- a/python/semantic_kernel/connectors/ai/google/vertex_ai/services/utils.py +++ b/python/semantic_kernel/connectors/ai/google/vertex_ai/services/utils.py @@ -1,5 +1,6 @@ # Copyright (c) Microsoft. All rights reserved. 
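+# json is needed because FunctionCallContent.arguments may be a JSON string
+# rather than a dict (see format_assistant_message below).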
+import json import logging from typing import Any @@ -119,9 +120,12 @@ def format_assistant_message(message: ChatMessageContent) -> list[Part]: if text_items: part.text = text_items[0].text if function_call_items: + # Convert the arguments to a dictionary if it is a string + args = function_call_items[0].arguments + args = json.loads(args) if isinstance(args, str) else args part.function_call = FunctionCall( name=function_call_items[0].name, - args=function_call_items[0].arguments, + args=args, ) return [part] diff --git a/python/tests/integration/completions/test_chat_completions.py b/python/tests/integration/completions/test_chat_completions.py index 289b90fef745..5f0e8671a5df 100644 --- a/python/tests/integration/completions/test_chat_completions.py +++ b/python/tests/integration/completions/test_chat_completions.py @@ -515,6 +515,45 @@ def services() -> dict[str, tuple[ChatCompletionClientBase, type[PromptExecution ["348"], id="google_ai_tool_call_auto", ), + pytest.param( + "google_ai", + { + "function_choice_behavior": FunctionChoiceBehavior.Auto( + auto_invoke=False, filters={"excluded_plugins": ["chat"]} + ) + }, + [ + ChatMessageContent(role=AuthorRole.USER, items=[TextContent(text="What is 3+345?")]), + ], + ["348"], + id="google_ai_tool_call_non_auto", + ), + pytest.param( + "google_ai", + {}, + [ + [ + ChatMessageContent( + role=AuthorRole.USER, + items=[TextContent(text="What was our 2024 revenue?")], + ), + ChatMessageContent( + role=AuthorRole.ASSISTANT, + items=[ + FunctionCallContent( + id="fin", name="finance-search", arguments='{"company": "contoso", "year": 2024}' + ) + ], + ), + ChatMessageContent( + role=AuthorRole.TOOL, + items=[FunctionResultContent(id="fin", name="finance-search", result="1.2B")], + ), + ], + ], + ["1.2"], + id="google_ai_tool_call_flow", + ), pytest.param( "vertex_ai", {}, @@ -562,6 +601,45 @@ def services() -> dict[str, tuple[ChatCompletionClientBase, type[PromptExecution ["348"], id="vertex_ai_tool_call_auto", ), + pytest.param( + "vertex_ai", + { + "function_choice_behavior": FunctionChoiceBehavior.Auto( + auto_invoke=False, filters={"excluded_plugins": ["chat"]} + ) + }, + [ + ChatMessageContent(role=AuthorRole.USER, items=[TextContent(text="What is 3+345?")]), + ], + ["348"], + id="vertex_ai_tool_call_non_auto", + ), + pytest.param( + "vertex_ai", + {}, + [ + [ + ChatMessageContent( + role=AuthorRole.USER, + items=[TextContent(text="What was our 2024 revenue?")], + ), + ChatMessageContent( + role=AuthorRole.ASSISTANT, + items=[ + FunctionCallContent( + id="fin", name="finance-search", arguments='{"company": "contoso", "year": 2024}' + ) + ], + ), + ChatMessageContent( + role=AuthorRole.TOOL, + items=[FunctionResultContent(id="fin", name="finance-search", result="1.2B")], + ), + ], + ], + ["1.2"], + id="vertex_ai_tool_call_flow", + ), ], ) From 80ab99b36b990c59c8714a43b512d11fe7e4215c Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Thu, 1 Aug 2024 14:56:31 -0700 Subject: [PATCH 05/13] Unit tests; next more coverage --- .../services/google_ai_chat_completion.py | 2 - .../ai/google/google_ai/services/utils.py | 42 +++- .../connectors/ai/google/shared_utils.py | 17 +- .../ai/google/vertex_ai/__init__.py | 19 ++ .../connectors/google/google_ai/conftest.py | 69 ++++++ .../test_google_ai_chat_completion.py | 215 +++++++++++++++++- .../connectors/google/vertex_ai/conftest.py | 62 +++++ .../test_vertex_ai_chat_completion.py | 198 +++++++++++++++- 8 files changed, 604 insertions(+), 20 deletions(-) diff --git 
a/python/semantic_kernel/connectors/ai/google/google_ai/services/google_ai_chat_completion.py b/python/semantic_kernel/connectors/ai/google/google_ai/services/google_ai_chat_completion.py index 2e8463b9f6f6..c52d24f69cb7 100644 --- a/python/semantic_kernel/connectors/ai/google/google_ai/services/google_ai_chat_completion.py +++ b/python/semantic_kernel/connectors/ai/google/google_ai/services/google_ai_chat_completion.py @@ -382,8 +382,6 @@ def _prepare_chat_history_for_request( chat_request_messages.append(Content(role="model", parts=format_assistant_message(message))) elif message.role == AuthorRole.TOOL: chat_request_messages.append(Content(role="function", parts=format_tool_message(message))) - else: - raise ValueError(f"Unsupported role: {message.role}") return chat_request_messages diff --git a/python/semantic_kernel/connectors/ai/google/google_ai/services/utils.py b/python/semantic_kernel/connectors/ai/google/google_ai/services/utils.py index 2570c82e0a3f..edc3a19d9dee 100644 --- a/python/semantic_kernel/connectors/ai/google/google_ai/services/utils.py +++ b/python/semantic_kernel/connectors/ai/google/google_ai/services/utils.py @@ -4,7 +4,7 @@ import logging from typing import Any -from google.generativeai.protos import Blob, Candidate, FunctionResponse, Part +from google.generativeai.protos import Blob, Candidate, FunctionCall, FunctionResponse, Part from semantic_kernel.connectors.ai.function_choice_behavior import FunctionCallChoiceConfiguration, FunctionChoiceType from semantic_kernel.connectors.ai.google.google_ai.google_ai_prompt_execution_settings import ( @@ -16,6 +16,7 @@ format_kernel_function_fully_qualified_name_to_gemini_function_name, ) from semantic_kernel.contents.chat_message_content import ChatMessageContent +from semantic_kernel.contents.function_call_content import FunctionCallContent from semantic_kernel.contents.function_result_content import FunctionResultContent from semantic_kernel.contents.image_content import ImageContent from semantic_kernel.contents.text_content import TextContent @@ -88,7 +89,44 @@ def format_assistant_message(message: ChatMessageContent) -> list[Part]: Returns: The formatted assistant message as a list of parts. 
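A note on the `_prepare_chat_history_for_request` change above: system messages are extracted separately (via `filter_system_message`) and sent as the model's `system_instruction`, so the catch-all `ValueError` is removed and unmapped roles are now skipped rather than raising. The resulting role mapping, sketched under that assumption:

```python
# Sketch of the SK-to-Gemini role mapping used when building Content objects.
from semantic_kernel.contents.utils.author_role import AuthorRole

SK_TO_GEMINI_ROLE = {
    AuthorRole.USER: "user",
    AuthorRole.ASSISTANT: "model",
    AuthorRole.TOOL: "function",
}
# AuthorRole.SYSTEM is intentionally absent: it becomes system_instruction instead.
```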
""" - return [Part(text=message.content)] + text_items: list[TextContent] = [] + function_call_items: list[FunctionCallContent] = [] + for item in message.items: + if isinstance(item, TextContent): + text_items.append(item) + elif isinstance(item, FunctionCallContent): + function_call_items.append(item) + else: + raise ServiceInvalidRequestError( + "Unsupported item type in Assistant message while formatting chat history for Vertex AI" + f" Inference: {type(item)}" + ) + + if len(text_items) > 1: + raise ServiceInvalidRequestError( + "Unsupported number of text items in Assistant message while formatting chat history for Vertex AI" + f" Inference: {len(text_items)}" + ) + + if len(function_call_items) > 1: + raise ServiceInvalidRequestError( + "Unsupported number of function call items in Assistant message while formatting chat history for Vertex AI" + f" Inference: {len(function_call_items)}" + ) + + part = Part() + if text_items: + part.text = text_items[0].text + if function_call_items: + # Convert the arguments to a dictionary if it is a string + args = function_call_items[0].arguments + args = json.loads(args) if isinstance(args, str) else args + part.function_call = FunctionCall( + name=function_call_items[0].name, + args=args, + ) + + return [part] def format_tool_message(message: ChatMessageContent) -> list[Part]: diff --git a/python/semantic_kernel/connectors/ai/google/shared_utils.py b/python/semantic_kernel/connectors/ai/google/shared_utils.py index 832a86970391..bf8bdc7d84f6 100644 --- a/python/semantic_kernel/connectors/ai/google/shared_utils.py +++ b/python/semantic_kernel/connectors/ai/google/shared_utils.py @@ -3,14 +3,9 @@ import asyncio import logging from collections.abc import Callable +from typing import TYPE_CHECKING from semantic_kernel.connectors.ai.function_choice_behavior import FunctionChoiceBehavior, FunctionChoiceType -from semantic_kernel.connectors.ai.google.google_ai.google_ai_prompt_execution_settings import ( - GoogleAIChatPromptExecutionSettings, -) -from semantic_kernel.connectors.ai.google.vertex_ai.vertex_ai_prompt_execution_settings import ( - VertexAIChatPromptExecutionSettings, -) from semantic_kernel.contents.chat_history import ChatHistory from semantic_kernel.contents.function_call_content import FunctionCallContent from semantic_kernel.contents.function_result_content import FunctionResultContent @@ -23,6 +18,14 @@ from semantic_kernel.functions.kernel_function_metadata import KernelFunctionMetadata from semantic_kernel.kernel import Kernel +if TYPE_CHECKING: + from semantic_kernel.connectors.ai.google.google_ai.google_ai_prompt_execution_settings import ( + GoogleAIChatPromptExecutionSettings, + ) + from semantic_kernel.connectors.ai.google.vertex_ai.vertex_ai_prompt_execution_settings import ( + VertexAIChatPromptExecutionSettings, + ) + logger: logging.Logger = logging.getLogger(__name__) @@ -110,7 +113,7 @@ def format_gemini_function_name_to_kernel_function_fully_qualified_name(gemini_f def configure_function_choice_behavior( - settings: GoogleAIChatPromptExecutionSettings | VertexAIChatPromptExecutionSettings, + settings: "GoogleAIChatPromptExecutionSettings | VertexAIChatPromptExecutionSettings", kernel: Kernel, callback: Callable[..., None], ): diff --git a/python/semantic_kernel/connectors/ai/google/vertex_ai/__init__.py b/python/semantic_kernel/connectors/ai/google/vertex_ai/__init__.py index e69de29bb2d1..c524ab06ecb1 100644 --- a/python/semantic_kernel/connectors/ai/google/vertex_ai/__init__.py +++ 
b/python/semantic_kernel/connectors/ai/google/vertex_ai/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) Microsoft. All rights reserved. + +from semantic_kernel.connectors.ai.google.vertex_ai.services.vertex_ai_chat_completion import VertexAIChatCompletion +from semantic_kernel.connectors.ai.google.vertex_ai.services.vertex_ai_text_completion import VertexAITextCompletion +from semantic_kernel.connectors.ai.google.vertex_ai.services.vertex_ai_text_embedding import VertexAITextEmbedding +from semantic_kernel.connectors.ai.google.vertex_ai.vertex_ai_prompt_execution_settings import ( + VertexAIChatPromptExecutionSettings, + VertexAIEmbeddingPromptExecutionSettings, + VertexAIPromptExecutionSettings, +) + +__all__ = [ + "VertexAIChatCompletion", + "VertexAIChatPromptExecutionSettings", + "VertexAIEmbeddingPromptExecutionSettings", + "VertexAIPromptExecutionSettings", + "VertexAITextCompletion", + "VertexAITextEmbedding", +] diff --git a/python/tests/unit/connectors/google/google_ai/conftest.py b/python/tests/unit/connectors/google/google_ai/conftest.py index 1318528d0943..7f27c217713b 100644 --- a/python/tests/unit/connectors/google/google_ai/conftest.py +++ b/python/tests/unit/connectors/google/google_ai/conftest.py @@ -59,6 +59,40 @@ def mock_google_ai_chat_completion_response() -> AsyncGenerateContentResponse: ) +@pytest.fixture() +def mock_google_ai_chat_completion_response_with_tool_call() -> AsyncGenerateContentResponse: + """Mock Google AI Chat Completion response.""" + candidate = protos.Candidate() + candidate.index = 0 + candidate.content = protos.Content( + role="user", + parts=[ + protos.Part( + function_call=protos.FunctionCall( + name="test_function", + args={"test_arg": "test_value"}, + ) + ) + ], + ) + candidate.finish_reason = protos.Candidate.FinishReason.STOP + + response = protos.GenerateContentResponse() + response.candidates.append(candidate) + response.usage_metadata = protos.GenerateContentResponse.UsageMetadata( + prompt_token_count=0, + cached_content_token_count=0, + candidates_token_count=0, + total_token_count=0, + ) + + return AsyncGenerateContentResponse( + done=True, + iterator=None, + result=response, + ) + + @pytest_asyncio.fixture() async def mock_google_ai_streaming_chat_completion_response() -> AsyncGenerateContentResponse: """Mock Google AI streaming Chat Completion response.""" @@ -84,6 +118,41 @@ async def mock_google_ai_streaming_chat_completion_response() -> AsyncGenerateCo ) +@pytest_asyncio.fixture() +async def mock_google_ai_streaming_chat_completion_response_with_tool_call() -> AsyncGenerateContentResponse: + """Mock Google AI streaming Chat Completion response with tool call.""" + candidate = protos.Candidate() + candidate.index = 0 + candidate.content = protos.Content( + role="user", + parts=[ + protos.Part( + function_call=protos.FunctionCall( + name="test_function", + args={"test_arg": "test_value"}, + ) + ) + ], + ) + candidate.finish_reason = protos.Candidate.FinishReason.STOP + + response = protos.GenerateContentResponse() + response.candidates.append(candidate) + response.usage_metadata = protos.GenerateContentResponse.UsageMetadata( + prompt_token_count=0, + cached_content_token_count=0, + candidates_token_count=0, + total_token_count=0, + ) + + iterable = MagicMock(spec=AsyncGenerator) + iterable.__aiter__.return_value = [response] + + return await AsyncGenerateContentResponse.from_aiterator( + iterator=iterable, + ) + + @pytest.fixture() def mock_google_ai_text_completion_response() -> AsyncGenerateContentResponse: """Mock Google AI Text 
Completion response.""" diff --git a/python/tests/unit/connectors/google/google_ai/services/test_google_ai_chat_completion.py b/python/tests/unit/connectors/google/google_ai/services/test_google_ai_chat_completion.py index 17598d198e61..78580c90c901 100644 --- a/python/tests/unit/connectors/google/google_ai/services/test_google_ai_chat_completion.py +++ b/python/tests/unit/connectors/google/google_ai/services/test_google_ai_chat_completion.py @@ -7,6 +7,7 @@ from google.generativeai.protos import Content from google.generativeai.types import GenerationConfig +from semantic_kernel.connectors.ai.function_choice_behavior import FunctionChoiceBehavior from semantic_kernel.connectors.ai.google.google_ai.google_ai_prompt_execution_settings import ( GoogleAIChatPromptExecutionSettings, ) @@ -15,7 +16,10 @@ from semantic_kernel.contents.chat_history import ChatHistory from semantic_kernel.contents.chat_message_content import ChatMessageContent from semantic_kernel.contents.utils.finish_reason import FinishReason -from semantic_kernel.exceptions.service_exceptions import ServiceInitializationError +from semantic_kernel.exceptions.service_exceptions import ( + ServiceInitializationError, + ServiceInvalidExecutionSettingsError, +) # region init @@ -74,7 +78,7 @@ def test_prompt_execution_settings_class(google_ai_unit_test_env) -> None: @pytest.mark.asyncio @patch.object(GenerativeModel, "generate_content_async", new_callable=AsyncMock) async def test_google_ai_chat_completion( - mock_google_model_generate_content_async, + mock_google_ai_model_generate_content_async, google_ai_unit_test_env, chat_history: ChatHistory, mock_google_ai_chat_completion_response, @@ -82,16 +86,18 @@ async def test_google_ai_chat_completion( """Test chat completion with GoogleAIChatCompletion""" settings = GoogleAIChatPromptExecutionSettings() - mock_google_model_generate_content_async.return_value = mock_google_ai_chat_completion_response + mock_google_ai_model_generate_content_async.return_value = mock_google_ai_chat_completion_response google_ai_chat_completion = GoogleAIChatCompletion() responses: list[ChatMessageContent] = await google_ai_chat_completion.get_chat_message_contents( chat_history, settings ) - mock_google_model_generate_content_async.assert_called_once_with( + mock_google_ai_model_generate_content_async.assert_called_once_with( contents=google_ai_chat_completion._prepare_chat_history_for_request(chat_history), generation_config=GenerationConfig(**settings.prepare_settings_dict()), + tools=None, + tool_config=None, ) assert len(responses) == 1 assert responses[0].role == "assistant" @@ -102,6 +108,100 @@ async def test_google_ai_chat_completion( assert responses[0].inner_content == mock_google_ai_chat_completion_response +@pytest.mark.asyncio +@patch.object(GenerativeModel, "generate_content_async", new_callable=AsyncMock) +async def test_google_ai_chat_completion_with_function_choice_behavior_fail_verification( + chat_history: ChatHistory, +) -> None: + """Test completion of GoogleAIChatCompletion with function choice behavior expect verification failure""" + + # Missing kernel + with pytest.raises(ServiceInvalidExecutionSettingsError): + settings = GoogleAIChatPromptExecutionSettings( + function_choice_behavior=FunctionChoiceBehavior.Auto(), + ) + + google_ai_chat_completion = GoogleAIChatCompletion() + + await google_ai_chat_completion.get_chat_message_contents( + chat_history=chat_history, + settings=settings, + ) + + +@pytest.mark.asyncio +@patch.object(GenerativeModel, "generate_content_async", 
new_callable=AsyncMock) +async def test_google_ai_chat_completion_with_function_choice_behavior( + mock_google_ai_model_generate_content_async, + kernel, + chat_history: ChatHistory, + mock_google_ai_chat_completion_response_with_tool_call, +) -> None: + """Test completion of GoogleAIChatCompletion with function choice behavior""" + mock_google_ai_model_generate_content_async.return_value = mock_google_ai_chat_completion_response_with_tool_call + + settings = GoogleAIChatPromptExecutionSettings( + function_choice_behavior=FunctionChoiceBehavior.Auto(), + ) + settings.function_choice_behavior.maximum_auto_invoke_attempts = 1 + + google_ai_chat_completion = GoogleAIChatCompletion() + + responses = await google_ai_chat_completion.get_chat_message_contents( + chat_history=chat_history, + settings=settings, + kernel=kernel, + ) + + # The function should be called twice: + # One for the tool call and one for the last completion + # after the maximum_auto_invoke_attempts is reached + assert mock_google_ai_model_generate_content_async.call_count == 2 + assert len(responses) == 1 + assert responses[0].role == "assistant" + # Google doesn't return STOP as the finish reason for tool calls + assert responses[0].finish_reason == FinishReason.STOP + + +@pytest.mark.asyncio +@patch.object(GenerativeModel, "generate_content_async", new_callable=AsyncMock) +async def test_google_ai_chat_completion_with_function_choice_behavior_no_tool_call( + mock_google_ai_model_generate_content_async, + kernel, + chat_history: ChatHistory, + mock_google_ai_chat_completion_response, +) -> None: + """Test completion of GoogleAIChatCompletion with function choice behavior but no tool call returned""" + mock_google_ai_model_generate_content_async.return_value = mock_google_ai_chat_completion_response + + settings = GoogleAIChatPromptExecutionSettings( + function_choice_behavior=FunctionChoiceBehavior.Auto(), + ) + settings.function_choice_behavior.maximum_auto_invoke_attempts = 1 + + google_ai_chat_completion = GoogleAIChatCompletion() + + responses = await google_ai_chat_completion.get_chat_message_contents( + chat_history=chat_history, + settings=settings, + kernel=kernel, + ) + + # Remove the latest message since the response from the model will be added to the chat history + # even when the model doesn't return a tool call + chat_history.remove_message(chat_history[-1]) + + mock_google_ai_model_generate_content_async.assert_awaited_once_with( + contents=google_ai_chat_completion._prepare_chat_history_for_request(chat_history), + generation_config=GenerationConfig(**settings.prepare_settings_dict()), + tools=None, + tool_config=None, + ) + assert len(responses) == 1 + assert responses[0].role == "assistant" + assert responses[0].content == mock_google_ai_chat_completion_response.candidates[0].content.parts[0].text + + # endregion chat completion @@ -109,15 +209,15 @@ async def test_google_ai_chat_completion( @pytest.mark.asyncio @patch.object(GenerativeModel, "generate_content_async", new_callable=AsyncMock) async def test_google_ai_streaming_chat_completion( - mock_google_model_generate_content_async, + mock_google_ai_model_generate_content_async, google_ai_unit_test_env, chat_history: ChatHistory, mock_google_ai_streaming_chat_completion_response, ) -> None: - """Test chat completion with GoogleAIChatCompletion""" + """Test streaming chat completion with GoogleAIChatCompletion""" settings = GoogleAIChatPromptExecutionSettings() - mock_google_model_generate_content_async.return_value = 
mock_google_ai_streaming_chat_completion_response + mock_google_ai_model_generate_content_async.return_value = mock_google_ai_streaming_chat_completion_response google_ai_chat_completion = GoogleAIChatCompletion() async for messages in google_ai_chat_completion.get_streaming_chat_message_contents(chat_history, settings): @@ -127,9 +227,108 @@ async def test_google_ai_streaming_chat_completion( assert "usage" in messages[0].metadata assert "prompt_feedback" in messages[0].metadata - mock_google_model_generate_content_async.assert_called_once_with( + mock_google_ai_model_generate_content_async.assert_called_once_with( + contents=google_ai_chat_completion._prepare_chat_history_for_request(chat_history), + generation_config=GenerationConfig(**settings.prepare_settings_dict()), + tools=None, + tool_config=None, + stream=True, + ) + + +@pytest.mark.asyncio +@patch.object(GenerativeModel, "generate_content_async", new_callable=AsyncMock) +async def test_google_ai_streaming_chat_completion_with_function_choice_behavior_fail_verification( + chat_history: ChatHistory, +) -> None: + """Test streaming chat completion of GoogleAIChatCompletion with function choice + behavior expect verification failure""" + + # Missing kernel + with pytest.raises(ServiceInvalidExecutionSettingsError): + settings = GoogleAIChatPromptExecutionSettings( + function_choice_behavior=FunctionChoiceBehavior.Auto(), + ) + + google_ai_chat_completion = GoogleAIChatCompletion() + + async for _ in google_ai_chat_completion.get_streaming_chat_message_contents( + chat_history=chat_history, + settings=settings, + ): + pass + + +@pytest.mark.asyncio +@patch.object(GenerativeModel, "generate_content_async", new_callable=AsyncMock) +async def test_google_ai_streaming_chat_completion_with_function_choice_behavior( + mock_google_ai_model_generate_content_async, + kernel, + chat_history: ChatHistory, + mock_google_ai_streaming_chat_completion_response_with_tool_call, +) -> None: + """Test streaming chat completion of GoogleAIChatCompletion with function choice behavior""" + mock_google_ai_model_generate_content_async.return_value = ( + mock_google_ai_streaming_chat_completion_response_with_tool_call + ) + + settings = GoogleAIChatPromptExecutionSettings( + function_choice_behavior=FunctionChoiceBehavior.Auto(), + ) + settings.function_choice_behavior.maximum_auto_invoke_attempts = 1 + + google_ai_chat_completion = GoogleAIChatCompletion() + + async for messages in google_ai_chat_completion.get_streaming_chat_message_contents( + chat_history, + settings, + kernel=kernel, + ): + assert len(messages) == 1 + assert messages[0].role == "assistant" + assert messages[0].content == "" + # Google doesn't return STOP as the finish reason for tool calls + assert messages[0].finish_reason == FinishReason.STOP + + # Streaming completion with tool call does not invoke the model + # after maximum_auto_invoke_attempts is reached + assert mock_google_ai_model_generate_content_async.call_count == 1 + + +@pytest.mark.asyncio +@patch.object(GenerativeModel, "generate_content_async", new_callable=AsyncMock) +async def test_google_ai_streaming_chat_completion_with_function_choice_behavior_no_tool_call( + mock_google_ai_model_generate_content_async, + kernel, + chat_history: ChatHistory, + mock_google_ai_streaming_chat_completion_response, +) -> None: + """Test completion of GoogleAIChatCompletion with function choice behavior but no tool call returned""" + mock_google_ai_model_generate_content_async.return_value = 
mock_google_ai_streaming_chat_completion_response + + settings = GoogleAIChatPromptExecutionSettings( + function_choice_behavior=FunctionChoiceBehavior.Auto(), + ) + settings.function_choice_behavior.maximum_auto_invoke_attempts = 1 + + google_ai_chat_completion = GoogleAIChatCompletion() + + async for messages in google_ai_chat_completion.get_streaming_chat_message_contents( + chat_history=chat_history, + settings=settings, + kernel=kernel, + ): + assert len(messages) == 1 + assert messages[0].role == "assistant" + assert ( + messages[0].content == mock_google_ai_streaming_chat_completion_response.candidates[0].content.parts[0].text + ) + + mock_google_ai_model_generate_content_async.assert_awaited_once_with( contents=google_ai_chat_completion._prepare_chat_history_for_request(chat_history), generation_config=GenerationConfig(**settings.prepare_settings_dict()), + tools=None, + tool_config=None, stream=True, ) diff --git a/python/tests/unit/connectors/google/vertex_ai/conftest.py b/python/tests/unit/connectors/google/vertex_ai/conftest.py index 3d999ae3f7b4..d1efbd80b19a 100644 --- a/python/tests/unit/connectors/google/vertex_ai/conftest.py +++ b/python/tests/unit/connectors/google/vertex_ai/conftest.py @@ -6,6 +6,7 @@ import pytest from google.cloud.aiplatform_v1beta1.types.content import Candidate, Content, Part from google.cloud.aiplatform_v1beta1.types.prediction_service import GenerateContentResponse +from google.cloud.aiplatform_v1beta1.types.tool import FunctionCall from vertexai.generative_models import GenerationResponse from vertexai.language_models import TextEmbedding @@ -55,6 +56,35 @@ def mock_vertex_ai_chat_completion_response() -> GenerationResponse: return GenerationResponse._from_gapic(response) +@pytest.fixture() +def mock_vertex_ai_chat_completion_response_with_tool_call() -> GenerationResponse: + """Mock Vertex AI Chat Completion response.""" + candidate = Candidate() + candidate.index = 0 + candidate.content = Content( + role="user", + parts=[ + Part( + function_call=FunctionCall( + name="test_function", + args={"test_arg": "test_value"}, + ) + ) + ], + ) + candidate.finish_reason = Candidate.FinishReason.STOP + + response = GenerateContentResponse() + response.candidates.append(candidate) + response.usage_metadata = GenerateContentResponse.UsageMetadata( + prompt_token_count=0, + candidates_token_count=0, + total_token_count=0, + ) + + return GenerationResponse._from_gapic(response) + + @pytest.fixture() def mock_vertex_ai_streaming_chat_completion_response() -> AsyncIterable[GenerationResponse]: """Mock Vertex AI streaming Chat Completion response.""" @@ -77,6 +107,38 @@ def mock_vertex_ai_streaming_chat_completion_response() -> AsyncIterable[Generat return iterable +@pytest.fixture() +def mock_vertex_ai_streaming_chat_completion_response_with_tool_call() -> AsyncIterable[GenerationResponse]: + """Mock Vertex AI streaming Chat Completion response.""" + candidate = Candidate() + candidate.index = 0 + candidate.content = Content( + role="user", + parts=[ + Part( + function_call=FunctionCall( + name="test_function", + args={"test_arg": "test_value"}, + ) + ) + ], + ) + candidate.finish_reason = Candidate.FinishReason.STOP + + response = GenerateContentResponse() + response.candidates.append(candidate) + response.usage_metadata = GenerateContentResponse.UsageMetadata( + prompt_token_count=0, + candidates_token_count=0, + total_token_count=0, + ) + + iterable = MagicMock(spec=AsyncGenerator) + iterable.__aiter__.return_value = [GenerationResponse._from_gapic(response)] 
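These streaming fixtures all rely on the same trick: a `MagicMock` with an `AsyncGenerator` spec whose `__aiter__` return value is preloaded with ready-made responses, which is enough for `async for` to consume it. The pattern, extracted as a standalone sketch with an illustrative helper name:

```python
from collections.abc import AsyncGenerator
from unittest.mock import MagicMock

def make_async_stream(*items) -> MagicMock:
    """Build a mock usable with `async for`, yielding the given items in order."""
    stream = MagicMock(spec=AsyncGenerator)
    stream.__aiter__.return_value = list(items)
    return stream

# Usage mirrors the fixtures above: preload pre-built response protos, then hand
# the mock to code that iterates the streaming response.
```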
+ + return iterable + + @pytest.fixture() def mock_vertex_ai_text_completion_response() -> GenerationResponse: """Mock Vertex AI Text Completion response.""" diff --git a/python/tests/unit/connectors/google/vertex_ai/services/test_vertex_ai_chat_completion.py b/python/tests/unit/connectors/google/vertex_ai/services/test_vertex_ai_chat_completion.py index 113ee470b2d1..957564fcd169 100644 --- a/python/tests/unit/connectors/google/vertex_ai/services/test_vertex_ai_chat_completion.py +++ b/python/tests/unit/connectors/google/vertex_ai/services/test_vertex_ai_chat_completion.py @@ -7,6 +7,7 @@ from google.cloud.aiplatform_v1beta1.types.content import Content from vertexai.generative_models import GenerativeModel +from semantic_kernel.connectors.ai.function_choice_behavior import FunctionChoiceBehavior from semantic_kernel.connectors.ai.google.vertex_ai.services.vertex_ai_chat_completion import VertexAIChatCompletion from semantic_kernel.connectors.ai.google.vertex_ai.vertex_ai_prompt_execution_settings import ( VertexAIChatPromptExecutionSettings, @@ -15,7 +16,10 @@ from semantic_kernel.contents.chat_history import ChatHistory from semantic_kernel.contents.chat_message_content import ChatMessageContent from semantic_kernel.contents.utils.finish_reason import FinishReason -from semantic_kernel.exceptions.service_exceptions import ServiceInitializationError +from semantic_kernel.exceptions.service_exceptions import ( + ServiceInitializationError, + ServiceInvalidExecutionSettingsError, +) # region init @@ -92,6 +96,8 @@ async def test_vertex_ai_chat_completion( mock_vertex_ai_model_generate_content_async.assert_called_once_with( contents=vertex_ai_chat_completion._prepare_chat_history_for_request(chat_history), generation_config=settings.prepare_settings_dict(), + tools=None, + tool_config=None, ) assert len(responses) == 1 assert responses[0].role == "assistant" @@ -102,6 +108,100 @@ async def test_vertex_ai_chat_completion( assert responses[0].inner_content == mock_vertex_ai_chat_completion_response +@pytest.mark.asyncio +@patch.object(GenerativeModel, "generate_content_async", new_callable=AsyncMock) +async def test_vertex_ai_chat_completion_with_function_choice_behavior_fail_verification( + chat_history: ChatHistory, +) -> None: + """Test completion of VertexAIChatCompletion with function choice behavior expect verification failure""" + + # Missing kernel + with pytest.raises(ServiceInvalidExecutionSettingsError): + settings = VertexAIChatPromptExecutionSettings( + function_choice_behavior=FunctionChoiceBehavior.Auto(), + ) + + vertex_ai_chat_completion = VertexAIChatCompletion() + + await vertex_ai_chat_completion.get_chat_message_contents( + chat_history=chat_history, + settings=settings, + ) + + +@pytest.mark.asyncio +@patch.object(GenerativeModel, "generate_content_async", new_callable=AsyncMock) +async def test_vertex_ai_chat_completion_with_function_choice_behavior( + mock_vertex_ai_model_generate_content_async, + kernel, + chat_history: ChatHistory, + mock_vertex_ai_chat_completion_response_with_tool_call, +) -> None: + """Test completion of VertexAIChatCompletion with function choice behavior""" + mock_vertex_ai_model_generate_content_async.return_value = mock_vertex_ai_chat_completion_response_with_tool_call + + settings = VertexAIChatPromptExecutionSettings( + function_choice_behavior=FunctionChoiceBehavior.Auto(), + ) + settings.function_choice_behavior.maximum_auto_invoke_attempts = 1 + + vertex_ai_chat_completion = VertexAIChatCompletion() + + responses = await 
vertex_ai_chat_completion.get_chat_message_contents( + chat_history=chat_history, + settings=settings, + kernel=kernel, + ) + + # The function should be called twice: + # One for the tool call and one for the last completion + # after the maximum_auto_invoke_attempts is reached + assert mock_vertex_ai_model_generate_content_async.call_count == 2 + assert len(responses) == 1 + assert responses[0].role == "assistant" + # Google doesn't return STOP as the finish reason for tool calls + assert responses[0].finish_reason == FinishReason.STOP + + +@pytest.mark.asyncio +@patch.object(GenerativeModel, "generate_content_async", new_callable=AsyncMock) +async def test_vertex_ai_chat_completion_with_function_choice_behavior_no_tool_call( + mock_vertex_ai_model_generate_content_async, + kernel, + chat_history: ChatHistory, + mock_vertex_ai_chat_completion_response, +) -> None: + """Test completion of VertexAIChatCompletion with function choice behavior but no tool call returned""" + mock_vertex_ai_model_generate_content_async.return_value = mock_vertex_ai_chat_completion_response + + settings = VertexAIChatPromptExecutionSettings( + function_choice_behavior=FunctionChoiceBehavior.Auto(), + ) + settings.function_choice_behavior.maximum_auto_invoke_attempts = 1 + + vertex_ai_chat_completion = VertexAIChatCompletion() + + responses = await vertex_ai_chat_completion.get_chat_message_contents( + chat_history=chat_history, + settings=settings, + kernel=kernel, + ) + + # Remove the latest message since the response from the model will be added to the chat history + # even when the model doesn't return a tool call + chat_history.remove_message(chat_history[-1]) + + mock_vertex_ai_model_generate_content_async.assert_awaited_once_with( + contents=vertex_ai_chat_completion._prepare_chat_history_for_request(chat_history), + generation_config=settings.prepare_settings_dict(), + tools=None, + tool_config=None, + ) + assert len(responses) == 1 + assert responses[0].role == "assistant" + assert responses[0].content == mock_vertex_ai_chat_completion_response.candidates[0].content.parts[0].text + + # endregion chat completion @@ -130,6 +230,102 @@ async def test_vertex_ai_streaming_chat_completion( mock_vertex_ai_model_generate_content_async.assert_called_once_with( contents=vertex_ai_chat_completion._prepare_chat_history_for_request(chat_history), generation_config=settings.prepare_settings_dict(), + tools=None, + tool_config=None, + stream=True, + ) + + +@pytest.mark.asyncio +@patch.object(GenerativeModel, "generate_content_async", new_callable=AsyncMock) +async def test_vertex_ai_streaming_chat_completion_with_function_choice_behavior_fail_verification( + chat_history: ChatHistory, +) -> None: + """Test streaming chat completion of VertexAIChatCompletion with function choice + behavior expect verification failure""" + + # Missing kernel + with pytest.raises(ServiceInvalidExecutionSettingsError): + settings = VertexAIChatPromptExecutionSettings( + function_choice_behavior=FunctionChoiceBehavior.Auto(), + ) + + vertex_ai_chat_completion = VertexAIChatCompletion() + + async for _ in vertex_ai_chat_completion.get_streaming_chat_message_contents( + chat_history=chat_history, + settings=settings, + ): + pass + + +@pytest.mark.asyncio +@patch.object(GenerativeModel, "generate_content_async", new_callable=AsyncMock) +async def test_vertex_ai_streaming_chat_completion_with_function_choice_behavior( + mock_vertex_ai_model_generate_content_async, + kernel, + chat_history: ChatHistory, + 
mock_vertex_ai_streaming_chat_completion_response_with_tool_call, +) -> None: + """Test streaming chat completion of VertexAIChatCompletion with function choice behavior""" + mock_vertex_ai_model_generate_content_async.return_value = ( + mock_vertex_ai_streaming_chat_completion_response_with_tool_call + ) + + settings = VertexAIChatPromptExecutionSettings( + function_choice_behavior=FunctionChoiceBehavior.Auto(), + ) + settings.function_choice_behavior.maximum_auto_invoke_attempts = 1 + + vertex_ai_chat_completion = VertexAIChatCompletion() + + async for messages in vertex_ai_chat_completion.get_streaming_chat_message_contents( + chat_history, + settings, + kernel=kernel, + ): + assert len(messages) == 1 + assert messages[0].role == "assistant" + assert messages[0].content == "" + # Google doesn't return STOP as the finish reason for tool calls + assert messages[0].finish_reason == FinishReason.STOP + + # Streaming completion with tool call does not invoke the model + # after maximum_auto_invoke_attempts is reached + assert mock_vertex_ai_model_generate_content_async.call_count == 1 + + +@pytest.mark.asyncio +@patch.object(GenerativeModel, "generate_content_async", new_callable=AsyncMock) +async def test_vertex_ai_streaming_chat_completion_with_function_choice_behavior_no_tool_call( + mock_vertex_ai_model_generate_content_async, + kernel, + chat_history: ChatHistory, + mock_vertex_ai_streaming_chat_completion_response, +) -> None: + """Test completion of VertexAIChatCompletion with function choice behavior but no tool call returned""" + mock_vertex_ai_model_generate_content_async.return_value = mock_vertex_ai_streaming_chat_completion_response + + settings = VertexAIChatPromptExecutionSettings( + function_choice_behavior=FunctionChoiceBehavior.Auto(), + ) + settings.function_choice_behavior.maximum_auto_invoke_attempts = 1 + + vertex_ai_chat_completion = VertexAIChatCompletion() + + async for messages in vertex_ai_chat_completion.get_streaming_chat_message_contents( + chat_history=chat_history, + settings=settings, + kernel=kernel, + ): + assert len(messages) == 1 + assert messages[0].role == "assistant" + + mock_vertex_ai_model_generate_content_async.assert_awaited_once_with( + contents=vertex_ai_chat_completion._prepare_chat_history_for_request(chat_history), + generation_config=settings.prepare_settings_dict(), + tools=None, + tool_config=None, stream=True, ) From 17105054c7c0b3a0e264afba8070a088bc8b3405 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Thu, 1 Aug 2024 18:51:58 -0700 Subject: [PATCH 06/13] Improve coverage --- .../ai/google/google_ai/services/utils.py | 103 +++++++----------- .../ai/google/vertex_ai/services/utils.py | 103 +++++++----------- .../services/vertex_ai_chat_completion.py | 2 - .../test_google_ai_chat_completion.py | 11 -- .../services/test_google_ai_utils.py | 39 ++++++- .../connectors/google/test_shared_utils.py | 26 ++++- .../test_vertex_ai_chat_completion.py | 11 -- .../services/test_vertex_ai_utils.py | 39 ++++++- 8 files changed, 181 insertions(+), 153 deletions(-) diff --git a/python/semantic_kernel/connectors/ai/google/google_ai/services/utils.py b/python/semantic_kernel/connectors/ai/google/google_ai/services/utils.py index edc3a19d9dee..5264358e90e0 100644 --- a/python/semantic_kernel/connectors/ai/google/google_ai/services/utils.py +++ b/python/semantic_kernel/connectors/ai/google/google_ai/services/utils.py @@ -55,22 +55,12 @@ def format_user_message(message: ChatMessageContent) -> list[Part]: Returns: The formatted user message as a list of parts. 
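In the hunk below, `format_user_message` loses its early `return [Part(text=message.content)]` path: every item in a user message now maps to its own `Part`, and image handling moves into a shared `_create_image_part` helper. A hedged sketch of the resulting behavior, reusing the content types from the unit tests:

```python
from semantic_kernel.contents.chat_message_content import ChatMessageContent
from semantic_kernel.contents.image_content import ImageContent
from semantic_kernel.contents.text_content import TextContent
from semantic_kernel.contents.utils.author_role import AuthorRole

message = ChatMessageContent(
    role=AuthorRole.USER,
    items=[
        TextContent(text="Describe this image."),
        ImageContent(data="image data", mime_type="image/png"),  # carries a data_uri
    ],
)
# format_user_message(message) -> [Part(text=...), Part(inline_data=Blob(...))]
# An ImageContent referencing only a remote URI raises ServiceInvalidRequestError,
# since the Gemini APIs do not accept images from arbitrary URIs.
```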
""" - if not any(isinstance(item, (ImageContent)) for item in message.items): - return [Part(text=message.content)] - parts: list[Part] = [] for item in message.items: if isinstance(item, TextContent): parts.append(Part(text=message.content)) elif isinstance(item, ImageContent): - if item.data_uri: - parts.append(Part(inline_data=Blob(mime_type=item.mime_type, data=item.data))) - else: - # The Google AI API doesn't support images from arbitrary URIs: - # https://github.com/google-gemini/generative-ai-python/issues/357 - raise ServiceInvalidRequestError( - "ImageContent without data_uri in User message while formatting chat history for Google AI" - ) + parts.append(_create_image_part(item)) else: raise ServiceInvalidRequestError( "Unsupported item type in User message while formatting chat history for Google AI" @@ -89,44 +79,29 @@ def format_assistant_message(message: ChatMessageContent) -> list[Part]: Returns: The formatted assistant message as a list of parts. """ - text_items: list[TextContent] = [] - function_call_items: list[FunctionCallContent] = [] + parts: list[Part] = [] for item in message.items: if isinstance(item, TextContent): - text_items.append(item) + parts.append(Part(text=item.text)) elif isinstance(item, FunctionCallContent): - function_call_items.append(item) + parts.append( + Part( + function_call=FunctionCall( + name=item.name, + # Convert the arguments to a dictionary if it is a string + args=json.loads(item.arguments) if isinstance(item.arguments, str) else item.arguments, + ) + ) + ) + elif isinstance(item, ImageContent): + parts.append(_create_image_part(item)) else: raise ServiceInvalidRequestError( "Unsupported item type in Assistant message while formatting chat history for Vertex AI" f" Inference: {type(item)}" ) - if len(text_items) > 1: - raise ServiceInvalidRequestError( - "Unsupported number of text items in Assistant message while formatting chat history for Vertex AI" - f" Inference: {len(text_items)}" - ) - - if len(function_call_items) > 1: - raise ServiceInvalidRequestError( - "Unsupported number of function call items in Assistant message while formatting chat history for Vertex AI" - f" Inference: {len(function_call_items)}" - ) - - part = Part() - if text_items: - part.text = text_items[0].text - if function_call_items: - # Convert the arguments to a dictionary if it is a string - args = function_call_items[0].arguments - args = json.loads(args) if isinstance(args, str) else args - part.function_call = FunctionCall( - name=function_call_items[0].name, - args=args, - ) - - return [part] + return parts def format_tool_message(message: ChatMessageContent) -> list[Part]: @@ -138,28 +113,23 @@ def format_tool_message(message: ChatMessageContent) -> list[Part]: Returns: The formatted tool message. 
""" - if len(message.items) != 1: - logger.warning( - "Unsupported number of items in Tool message while formatting chat history for Google AI: " - f"{len(message.items)}" - ) - - if not isinstance(message.items[0], FunctionResultContent): - raise ValueError("No FunctionResultContent found in the message items") - - gemini_function_name = format_function_result_content_name_to_gemini_function_name(message.items[0]) - - return [ - Part( - function_response=FunctionResponse( - name=gemini_function_name, - response={ - "name": gemini_function_name, - "content": json.dumps(message.items[0].result), - }, + parts: list[Part] = [] + for item in message.items: + if isinstance(item, FunctionResultContent): + gemini_function_name = format_function_result_content_name_to_gemini_function_name(item) + parts.append( + Part( + function_response=FunctionResponse( + name=gemini_function_name, + response={ + "name": gemini_function_name, + "content": item.result, + }, + ) + ) ) - ) - ] + + return parts def kernel_function_metadata_to_google_ai_function_call_format(metadata: KernelFunctionMetadata) -> dict[str, Any]: @@ -195,3 +165,14 @@ def update_settings_from_function_choice_configuration( ] } ] + + +def _create_image_part(image_content: ImageContent) -> Part: + if image_content.data_uri: + return Part(inline_data=Blob(mime_type=image_content.mime_type, data=image_content.data)) + + # The Google AI API doesn't support images from arbitrary URIs: + # https://github.com/google-gemini/generative-ai-python/issues/357 + raise ServiceInvalidRequestError( + "ImageContent without data_uri in User message while formatting chat history for Google AI" + ) diff --git a/python/semantic_kernel/connectors/ai/google/vertex_ai/services/utils.py b/python/semantic_kernel/connectors/ai/google/vertex_ai/services/utils.py index 6fb1ec652a01..0f3485ed8e51 100644 --- a/python/semantic_kernel/connectors/ai/google/vertex_ai/services/utils.py +++ b/python/semantic_kernel/connectors/ai/google/vertex_ai/services/utils.py @@ -57,22 +57,12 @@ def format_user_message(message: ChatMessageContent) -> list[Part]: Returns: The formatted user message as a list of parts. """ - if not any(isinstance(item, (ImageContent)) for item in message.items): - return [Part(text=message.content)] - parts: list[Part] = [] for item in message.items: if isinstance(item, TextContent): parts.append(Part(text=message.content)) elif isinstance(item, ImageContent): - if item.data_uri: - parts.append(Part(inline_data=Blob(mime_type=item.mime_type, data=item.data))) - else: - # The Google AI API doesn't support images from arbitrary URIs: - # https://github.com/google-gemini/generative-ai-python/issues/357 - raise ServiceInvalidRequestError( - "ImageContent without data_uri in User message while formatting chat history for Vertex AI" - ) + parts.append(_create_image_part(item)) else: raise ServiceInvalidRequestError( "Unsupported item type in User message while formatting chat history for Vertex AI" @@ -91,44 +81,29 @@ def format_assistant_message(message: ChatMessageContent) -> list[Part]: Returns: The formatted assistant message as a list of parts. 
""" - text_items: list[TextContent] = [] - function_call_items: list[FunctionCallContent] = [] + parts: list[Part] = [] for item in message.items: if isinstance(item, TextContent): - text_items.append(item) + parts.append(Part(text=item.text)) elif isinstance(item, FunctionCallContent): - function_call_items.append(item) + parts.append( + Part( + function_call=FunctionCall( + name=item.name, + # Convert the arguments to a dictionary if it is a string + args=json.loads(item.arguments) if isinstance(item.arguments, str) else item.arguments, + ) + ) + ) + elif isinstance(item, ImageContent): + parts.append(_create_image_part(item)) else: raise ServiceInvalidRequestError( "Unsupported item type in Assistant message while formatting chat history for Vertex AI" f" Inference: {type(item)}" ) - if len(text_items) > 1: - raise ServiceInvalidRequestError( - "Unsupported number of text items in Assistant message while formatting chat history for Vertex AI" - f" Inference: {len(text_items)}" - ) - - if len(function_call_items) > 1: - raise ServiceInvalidRequestError( - "Unsupported number of function call items in Assistant message while formatting chat history for Vertex AI" - f" Inference: {len(function_call_items)}" - ) - - part = Part() - if text_items: - part.text = text_items[0].text - if function_call_items: - # Convert the arguments to a dictionary if it is a string - args = function_call_items[0].arguments - args = json.loads(args) if isinstance(args, str) else args - part.function_call = FunctionCall( - name=function_call_items[0].name, - args=args, - ) - - return [part] + return parts def format_tool_message(message: ChatMessageContent) -> list[Part]: @@ -140,28 +115,23 @@ def format_tool_message(message: ChatMessageContent) -> list[Part]: Returns: The formatted tool message. 
""" - if len(message.items) != 1: - logger.warning( - "Unsupported number of items in Tool message while formatting chat history for Vertex AI: " - f"{len(message.items)}" - ) - - if not isinstance(message.items[0], FunctionResultContent): - raise ValueError("No FunctionResultContent found in the message items") - - gemini_function_name = format_function_result_content_name_to_gemini_function_name(message.items[0]) + parts: list[Part] = [] + for item in message.items: + if isinstance(item, FunctionResultContent): + gemini_function_name = format_function_result_content_name_to_gemini_function_name(item) + parts.append( + Part( + function_response=FunctionResponse( + name=gemini_function_name, + response={ + "name": gemini_function_name, + "content": item.result, + }, + ) + ) + ) - return [ - Part( - function_response=FunctionResponse( - name=gemini_function_name, - response={ - "name": gemini_function_name, - "content": message.items[0].result, - }, - ), - ) - ] + return parts def kernel_function_metadata_to_vertex_ai_function_call_format(metadata: KernelFunctionMetadata) -> dict[str, Any]: @@ -197,3 +167,14 @@ def update_settings_from_function_choice_configuration( ] ) ] + + +def _create_image_part(image_content: ImageContent) -> Part: + if image_content.data_uri: + return Part(inline_data=Blob(mime_type=image_content.mime_type, data=image_content.data)) + + # The Google AI API doesn't support images from arbitrary URIs: + # https://github.com/google-gemini/generative-ai-python/issues/357 + raise ServiceInvalidRequestError( + "ImageContent without data_uri in User message while formatting chat history for Google AI" + ) diff --git a/python/semantic_kernel/connectors/ai/google/vertex_ai/services/vertex_ai_chat_completion.py b/python/semantic_kernel/connectors/ai/google/vertex_ai/services/vertex_ai_chat_completion.py index 16c4cca58720..1696b7a67e4d 100644 --- a/python/semantic_kernel/connectors/ai/google/vertex_ai/services/vertex_ai_chat_completion.py +++ b/python/semantic_kernel/connectors/ai/google/vertex_ai/services/vertex_ai_chat_completion.py @@ -372,8 +372,6 @@ def _prepare_chat_history_for_request( chat_request_messages.append(Content(role="model", parts=format_assistant_message(message))) elif message.role == AuthorRole.TOOL: chat_request_messages.append(Content(role="function", parts=format_tool_message(message))) - else: - raise ValueError(f"Unsupported role: {message.role}") return chat_request_messages diff --git a/python/tests/unit/connectors/google/google_ai/services/test_google_ai_chat_completion.py b/python/tests/unit/connectors/google/google_ai/services/test_google_ai_chat_completion.py index 78580c90c901..0cd19185ad94 100644 --- a/python/tests/unit/connectors/google/google_ai/services/test_google_ai_chat_completion.py +++ b/python/tests/unit/connectors/google/google_ai/services/test_google_ai_chat_completion.py @@ -355,14 +355,3 @@ def test_google_ai_chat_completion_parse_chat_history_correctly(google_ai_unit_t assert parsed_chat_history[0].parts[0].text == "test_user_message" assert parsed_chat_history[1].role == "model" assert parsed_chat_history[1].parts[0].text == "test_assistant_message" - - -def test_google_ai_chat_completion_parse_chat_history_throw_unsupported_message(google_ai_unit_test_env) -> None: - """Test _prepare_chat_history_for_request method with unsupported message type""" - google_ai_chat_completion = GoogleAIChatCompletion() - - chat_history = ChatHistory() - chat_history.add_tool_message("test_tool_message") - - with pytest.raises(ValueError): - _ = 
google_ai_chat_completion._prepare_chat_history_for_request(chat_history) diff --git a/python/tests/unit/connectors/google/google_ai/services/test_google_ai_utils.py b/python/tests/unit/connectors/google/google_ai/services/test_google_ai_utils.py index 25619d9d4a07..1d2a6355e70a 100644 --- a/python/tests/unit/connectors/google/google_ai/services/test_google_ai_utils.py +++ b/python/tests/unit/connectors/google/google_ai/services/test_google_ai_utils.py @@ -5,10 +5,12 @@ from semantic_kernel.connectors.ai.google.google_ai.services.utils import ( finish_reason_from_google_ai_to_semantic_kernel, + format_assistant_message, format_user_message, ) from semantic_kernel.contents.chat_message_content import ChatMessageContent from semantic_kernel.contents.function_call_content import FunctionCallContent +from semantic_kernel.contents.function_result_content import FunctionResultContent from semantic_kernel.contents.image_content import ImageContent from semantic_kernel.contents.text_content import TextContent from semantic_kernel.contents.utils.author_role import AuthorRole @@ -55,13 +57,10 @@ def test_format_user_message(): def test_format_user_message_throws_with_unsupported_items() -> None: """Test format_user_message with unsupported items.""" # Test with unsupported items, any item other than TextContent and ImageContent should raise an error - # Note that method format_user_message will use the content of the message if no ImageContent is found, - # so we need to add an ImageContent to the message to trigger the error user_message = ChatMessageContent( role=AuthorRole.USER, items=[ FunctionCallContent(), - ImageContent(data="image data", mime_type="image/png"), ], ) with pytest.raises(ServiceInvalidRequestError): @@ -76,3 +75,37 @@ def test_format_user_message_throws_with_unsupported_items() -> None: ) with pytest.raises(ServiceInvalidRequestError): format_user_message(user_message) + + +def test_format_assistant_message() -> None: + assistant_message = ChatMessageContent( + role=AuthorRole.ASSISTANT, + items=[ + TextContent(text="test"), + FunctionCallContent(name="test_function", arguments={}), + ImageContent(data="image data", mime_type="image/png"), + ], + ) + + formatted_assistant_message = format_assistant_message(assistant_message) + assert isinstance(formatted_assistant_message, list) + assert len(formatted_assistant_message) == 3 + assert isinstance(formatted_assistant_message[0], Part) + assert formatted_assistant_message[0].text == "test" + assert isinstance(formatted_assistant_message[1], Part) + assert formatted_assistant_message[1].function_call.name == "test_function" + assert formatted_assistant_message[1].function_call.args == {} + assert isinstance(formatted_assistant_message[2], Part) + assert formatted_assistant_message[2].inline_data + + +def test_format_assistant_message_with_unsupported_items() -> None: + assistant_message = ChatMessageContent( + role=AuthorRole.ASSISTANT, + items=[ + FunctionResultContent(id="test_id", function_name="test_function"), + ], + ) + + with pytest.raises(ServiceInvalidRequestError): + format_assistant_message(assistant_message) diff --git a/python/tests/unit/connectors/google/test_shared_utils.py b/python/tests/unit/connectors/google/test_shared_utils.py index 5914e01fac63..599c2f9a0364 100644 --- a/python/tests/unit/connectors/google/test_shared_utils.py +++ b/python/tests/unit/connectors/google/test_shared_utils.py @@ -3,7 +3,13 @@ import pytest -from semantic_kernel.connectors.ai.google.shared_utils import filter_system_message +from 
semantic_kernel.connectors.ai.function_choice_behavior import FunctionChoiceType +from semantic_kernel.connectors.ai.google.shared_utils import ( + FUNCTION_CHOICE_TYPE_TO_GOOGLE_FUNCTION_CALLING_MODE, + GEMINI_FUNCTION_NAME_SEPARATOR, + filter_system_message, + format_gemini_function_name_to_kernel_function_fully_qualified_name, +) from semantic_kernel.contents.chat_history import ChatHistory from semantic_kernel.exceptions.service_exceptions import ServiceInvalidRequestError @@ -27,3 +33,21 @@ def test_first_system_message(): chat_history.add_system_message("System message 2") with pytest.raises(ServiceInvalidRequestError): filter_system_message(chat_history) + + +def test_function_choice_type_to_google_function_calling_mode_contain_all_types() -> None: + assert FunctionChoiceType.AUTO in FUNCTION_CHOICE_TYPE_TO_GOOGLE_FUNCTION_CALLING_MODE + assert FunctionChoiceType.NONE in FUNCTION_CHOICE_TYPE_TO_GOOGLE_FUNCTION_CALLING_MODE + assert FunctionChoiceType.REQUIRED in FUNCTION_CHOICE_TYPE_TO_GOOGLE_FUNCTION_CALLING_MODE + + +def test_format_gemini_function_name_to_kernel_function_fully_qualified_name() -> None: + # Contains the separator + gemini_function_name = f"plugin{GEMINI_FUNCTION_NAME_SEPARATOR}function" + assert ( + format_gemini_function_name_to_kernel_function_fully_qualified_name(gemini_function_name) == "plugin-function" + ) + + # Doesn't contain the separator + gemini_function_name = "function" + assert format_gemini_function_name_to_kernel_function_fully_qualified_name(gemini_function_name) == "function" diff --git a/python/tests/unit/connectors/google/vertex_ai/services/test_vertex_ai_chat_completion.py b/python/tests/unit/connectors/google/vertex_ai/services/test_vertex_ai_chat_completion.py index 957564fcd169..650bc64eb851 100644 --- a/python/tests/unit/connectors/google/vertex_ai/services/test_vertex_ai_chat_completion.py +++ b/python/tests/unit/connectors/google/vertex_ai/services/test_vertex_ai_chat_completion.py @@ -352,14 +352,3 @@ def test_vertex_ai_chat_completion_parse_chat_history_correctly(vertex_ai_unit_t assert parsed_chat_history[0].parts[0].text == "test_user_message" assert parsed_chat_history[1].role == "model" assert parsed_chat_history[1].parts[0].text == "test_assistant_message" - - -def test_vertex_ai_chat_completion_parse_chat_history_throw_unsupported_message(vertex_ai_unit_test_env) -> None: - """Test _prepare_chat_history_for_request method with unsupported message type""" - vertex_ai_chat_completion = VertexAIChatCompletion() - - chat_history = ChatHistory() - chat_history.add_tool_message("test_tool_message") - - with pytest.raises(ValueError): - _ = vertex_ai_chat_completion._prepare_chat_history_for_request(chat_history) diff --git a/python/tests/unit/connectors/google/vertex_ai/services/test_vertex_ai_utils.py b/python/tests/unit/connectors/google/vertex_ai/services/test_vertex_ai_utils.py index d519db2463c7..e874262e69ef 100644 --- a/python/tests/unit/connectors/google/vertex_ai/services/test_vertex_ai_utils.py +++ b/python/tests/unit/connectors/google/vertex_ai/services/test_vertex_ai_utils.py @@ -5,10 +5,12 @@ from semantic_kernel.connectors.ai.google.vertex_ai.services.utils import ( finish_reason_from_vertex_ai_to_semantic_kernel, + format_assistant_message, format_user_message, ) from semantic_kernel.contents.chat_message_content import ChatMessageContent from semantic_kernel.contents.function_call_content import FunctionCallContent +from semantic_kernel.contents.function_result_content import FunctionResultContent from 
semantic_kernel.contents.image_content import ImageContent from semantic_kernel.contents.text_content import TextContent from semantic_kernel.contents.utils.author_role import AuthorRole @@ -55,13 +57,10 @@ def test_format_user_message(): def test_format_user_message_throws_with_unsupported_items() -> None: """Test format_user_message with unsupported items.""" # Test with unsupported items, any item other than TextContent and ImageContent should raise an error - # Note that method format_user_message will use the content of the message if no ImageContent is found, - # so we need to add an ImageContent to the message to trigger the error user_message = ChatMessageContent( role=AuthorRole.USER, items=[ FunctionCallContent(), - ImageContent(data="image data", mime_type="image/png"), ], ) with pytest.raises(ServiceInvalidRequestError): @@ -76,3 +75,37 @@ def test_format_user_message_throws_with_unsupported_items() -> None: ) with pytest.raises(ServiceInvalidRequestError): format_user_message(user_message) + + +def test_format_assistant_message() -> None: + assistant_message = ChatMessageContent( + role=AuthorRole.ASSISTANT, + items=[ + TextContent(text="test"), + FunctionCallContent(name="test_function", arguments={}), + ImageContent(data="image data", mime_type="image/png"), + ], + ) + + formatted_assistant_message = format_assistant_message(assistant_message) + assert isinstance(formatted_assistant_message, list) + assert len(formatted_assistant_message) == 3 + assert isinstance(formatted_assistant_message[0], Part) + assert formatted_assistant_message[0].text == "test" + assert isinstance(formatted_assistant_message[1], Part) + assert formatted_assistant_message[1].function_call.name == "test_function" + assert formatted_assistant_message[1].function_call.args == {} + assert isinstance(formatted_assistant_message[2], Part) + assert formatted_assistant_message[2].inline_data + + +def test_format_assistant_message_with_unsupported_items() -> None: + assistant_message = ChatMessageContent( + role=AuthorRole.ASSISTANT, + items=[ + FunctionResultContent(id="test_id", function_name="test_function"), + ], + ) + + with pytest.raises(ServiceInvalidRequestError): + format_assistant_message(assistant_message) From e4b3fd37e2467b3dcb878efe90b7a01a4401a1a7 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Thu, 1 Aug 2024 23:34:26 -0700 Subject: [PATCH 07/13] Fix integration test --- .../azure_ai_inference_prompt_execution_settings.py | 2 ++ .../google_ai/google_ai_prompt_execution_settings.py | 2 ++ .../connectors/ai/google/google_ai/services/utils.py | 3 ++- .../connectors/ai/google/vertex_ai/services/utils.py | 3 ++- .../vertex_ai/vertex_ai_prompt_execution_settings.py | 2 ++ .../open_ai_prompt_execution_settings.py | 7 +++++-- 6 files changed, 15 insertions(+), 4 deletions(-) diff --git a/python/semantic_kernel/connectors/ai/azure_ai_inference/azure_ai_inference_prompt_execution_settings.py b/python/semantic_kernel/connectors/ai/azure_ai_inference/azure_ai_inference_prompt_execution_settings.py index 5e81ad8b76e4..5b3c7bcf261f 100644 --- a/python/semantic_kernel/connectors/ai/azure_ai_inference/azure_ai_inference_prompt_execution_settings.py +++ b/python/semantic_kernel/connectors/ai/azure_ai_inference/azure_ai_inference_prompt_execution_settings.py @@ -30,6 +30,8 @@ class AzureAIInferencePromptExecutionSettings(PromptExecutionSettings): class AzureAIInferenceChatPromptExecutionSettings(AzureAIInferencePromptExecutionSettings): """Azure AI Inference Chat Prompt Execution Settings.""" + # Do not set the 
tools and tool_choice manually. + # They are set by the service based on the function choice configuration. tools: list[dict[str, Any]] | None = Field(None, max_length=64) tool_choice: str | None = None diff --git a/python/semantic_kernel/connectors/ai/google/google_ai/google_ai_prompt_execution_settings.py b/python/semantic_kernel/connectors/ai/google/google_ai/google_ai_prompt_execution_settings.py index a5eff4dc2b98..2cb4895fa20a 100644 --- a/python/semantic_kernel/connectors/ai/google/google_ai/google_ai_prompt_execution_settings.py +++ b/python/semantic_kernel/connectors/ai/google/google_ai/google_ai_prompt_execution_settings.py @@ -35,6 +35,8 @@ class GoogleAITextPromptExecutionSettings(GoogleAIPromptExecutionSettings): class GoogleAIChatPromptExecutionSettings(GoogleAIPromptExecutionSettings): """Google AI Chat Prompt Execution Settings.""" + # Do not set the tools and tool_config manually. + # They are set by the service based on the function choice configuration. tools: list[dict[str, Any]] | None = Field(None, max_length=64) tool_config: dict[str, Any] | None = None diff --git a/python/semantic_kernel/connectors/ai/google/google_ai/services/utils.py b/python/semantic_kernel/connectors/ai/google/google_ai/services/utils.py index 5264358e90e0..5199384ba967 100644 --- a/python/semantic_kernel/connectors/ai/google/google_ai/services/utils.py +++ b/python/semantic_kernel/connectors/ai/google/google_ai/services/utils.py @@ -82,7 +82,8 @@ def format_assistant_message(message: ChatMessageContent) -> list[Part]: parts: list[Part] = [] for item in message.items: if isinstance(item, TextContent): - parts.append(Part(text=item.text)) + if item.text: + parts.append(Part(text=item.text)) elif isinstance(item, FunctionCallContent): parts.append( Part( diff --git a/python/semantic_kernel/connectors/ai/google/vertex_ai/services/utils.py b/python/semantic_kernel/connectors/ai/google/vertex_ai/services/utils.py index 0f3485ed8e51..9b78f67e74a0 100644 --- a/python/semantic_kernel/connectors/ai/google/vertex_ai/services/utils.py +++ b/python/semantic_kernel/connectors/ai/google/vertex_ai/services/utils.py @@ -84,7 +84,8 @@ def format_assistant_message(message: ChatMessageContent) -> list[Part]: parts: list[Part] = [] for item in message.items: if isinstance(item, TextContent): - parts.append(Part(text=item.text)) + if item.text: + parts.append(Part(text=item.text)) elif isinstance(item, FunctionCallContent): parts.append( Part( diff --git a/python/semantic_kernel/connectors/ai/google/vertex_ai/vertex_ai_prompt_execution_settings.py b/python/semantic_kernel/connectors/ai/google/vertex_ai/vertex_ai_prompt_execution_settings.py index 99389e95eb9f..062568d21fbb 100644 --- a/python/semantic_kernel/connectors/ai/google/vertex_ai/vertex_ai_prompt_execution_settings.py +++ b/python/semantic_kernel/connectors/ai/google/vertex_ai/vertex_ai_prompt_execution_settings.py @@ -36,6 +36,8 @@ class VertexAITextPromptExecutionSettings(VertexAIPromptExecutionSettings): class VertexAIChatPromptExecutionSettings(VertexAIPromptExecutionSettings): """Vertex AI Chat Prompt Execution Settings.""" + # Do not set the tools and tool_config manually. + # They are set by the service based on the function choice configuration. 
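The comments added across these settings classes all state the same contract: callers leave `tools`, `tool_choice`, and `tool_config` unset and express intent through `function_choice_behavior`, which the connector translates at request time. A hedged usage sketch mirroring the integration-test parameters:

```python
from semantic_kernel.connectors.ai.function_choice_behavior import FunctionChoiceBehavior
from semantic_kernel.connectors.ai.google.vertex_ai import VertexAIChatPromptExecutionSettings

settings = VertexAIChatPromptExecutionSettings(
    max_tokens=256,
    function_choice_behavior=FunctionChoiceBehavior.Auto(
        auto_invoke=True, filters={"excluded_plugins": ["chat"]}
    ),
)
# settings.tools and settings.tool_config stay None here; the service fills them
# in from the kernel's functions just before calling the model.
```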
tools: list[Tool] | None = Field(None, max_length=64)
     tool_config: ToolConfig | None = None
 
diff --git a/python/semantic_kernel/connectors/ai/open_ai/prompt_execution_settings/open_ai_prompt_execution_settings.py b/python/semantic_kernel/connectors/ai/open_ai/prompt_execution_settings/open_ai_prompt_execution_settings.py
index 8cde4a8cdaa9..480d8f94342e 100644
--- a/python/semantic_kernel/connectors/ai/open_ai/prompt_execution_settings/open_ai_prompt_execution_settings.py
+++ b/python/semantic_kernel/connectors/ai/open_ai/prompt_execution_settings/open_ai_prompt_execution_settings.py
@@ -62,13 +62,16 @@ class OpenAIChatPromptExecutionSettings(OpenAIPromptExecutionSettings):
     """Specific settings for the Chat Completion endpoint."""
 
     response_format: dict[Literal["type"], Literal["text", "json_object"]] | None = None
-    tools: list[dict[str, Any]] | None = Field(None, max_length=64)
-    tool_choice: str | None = None
     function_call: str | None = None
     functions: list[dict[str, Any]] | None = None
     messages: list[dict[str, Any]] | None = None
     function_call_behavior: FunctionCallBehavior | None = Field(None, exclude=True)
 
+    # Do not set the tools and tool_choice manually.
+    # They are set by the service based on the function choice configuration.
+    tools: list[dict[str, Any]] | None = Field(None, max_length=64)
+    tool_choice: str | None = None
+
     @field_validator("functions", "function_call", mode="after")
     @classmethod
     def validate_function_call(cls, v: str | list[dict[str, Any]] | None = None):

From e77ceffc1f38230af07cf39494f605cb86dd4c97 Mon Sep 17 00:00:00 2001
From: Tao Chen
Date: Fri, 2 Aug 2024 00:24:02 -0700
Subject: [PATCH 08/13] Update README

---
 python/semantic_kernel/connectors/ai/google/README.md | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/python/semantic_kernel/connectors/ai/google/README.md b/python/semantic_kernel/connectors/ai/google/README.md
index febfe59ff94f..03f132cd518c 100644
--- a/python/semantic_kernel/connectors/ai/google/README.md
+++ b/python/semantic_kernel/connectors/ai/google/README.md
@@ -48,3 +48,7 @@ kernel.add_service(
 ```
 > Alternatively, you can use an .env file to store the model id and project id.
+
+## Why is there code that looks almost identical in the two connectors?
+
+The two connectors have very similar implementations, including the utils files. However, they are fundamentally different because they depend on different packages from Google. Although many of the type names are identical, the types come from different Google packages and are not interchangeable.
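To make the distinction concrete, here is a minimal sketch, not part of the patch, of constructing the two chat services side by side. The import paths and constructor parameter names are assumed to mirror the Vertex AI service signatures shown in this series, and the model and project ids are placeholders:

```python
# Parallel APIs, different underlying Google packages: google.generativeai
# backs the Google AI connector (API-key based), while the vertexai SDK backs
# the Vertex AI connector (GCP project based).
from semantic_kernel.connectors.ai.google.google_ai import GoogleAIChatCompletion
from semantic_kernel.connectors.ai.google.vertex_ai import VertexAIChatCompletion

# API key is read from the environment or an .env file.
google_ai_service = GoogleAIChatCompletion(gemini_model_id="gemini-1.5-flash")  # placeholder model id

vertex_ai_service = VertexAIChatCompletion(
    project_id="my-gcp-project",  # placeholder
    gemini_model_id="gemini-1.5-flash",
)
```

The two services are interchangeable at the Semantic Kernel level, but SDK-level types such as `Part` or `Tool` from one package must never be passed to the other.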
\ No newline at end of file From 414ee1cad100e71127c56d0e28cece506d10b8f4 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Fri, 2 Aug 2024 10:12:55 -0700 Subject: [PATCH 09/13] Add region to Vertex AI --- .../vertex_ai/services/vertex_ai_chat_completion.py | 8 ++++++-- .../vertex_ai/services/vertex_ai_text_completion.py | 7 +++++-- .../google/vertex_ai/services/vertex_ai_text_embedding.py | 5 ++++- .../connectors/ai/google/vertex_ai/vertex_ai_settings.py | 3 +++ 4 files changed, 18 insertions(+), 5 deletions(-) diff --git a/python/semantic_kernel/connectors/ai/google/vertex_ai/services/vertex_ai_chat_completion.py b/python/semantic_kernel/connectors/ai/google/vertex_ai/services/vertex_ai_chat_completion.py index 1696b7a67e4d..7ebe3770d15a 100644 --- a/python/semantic_kernel/connectors/ai/google/vertex_ai/services/vertex_ai_chat_completion.py +++ b/python/semantic_kernel/connectors/ai/google/vertex_ai/services/vertex_ai_chat_completion.py @@ -58,6 +58,7 @@ class VertexAIChatCompletion(VertexAIBase, ChatCompletionClientBase): def __init__( self, project_id: str | None = None, + region: str | None = None, gemini_model_id: str | None = None, service_id: str | None = None, env_file_path: str | None = None, @@ -69,9 +70,11 @@ def __init__( The following environment variables are used: - VERTEX_AI_GEMINI_MODEL_ID - VERTEX_AI_PROJECT_ID + - VERTEX_AI_REGION Args: project_id (str): The Google Cloud project ID. + region (str): The Google Cloud region. gemini_model_id (str): The Gemini model ID. service_id (str): The Vertex AI service ID. env_file_path (str): The path to the environment file. @@ -80,6 +83,7 @@ def __init__( try: vertex_ai_settings = VertexAISettings.create( project_id=project_id, + region=region, gemini_model_id=gemini_model_id, env_file_path=env_file_path, env_file_encoding=env_file_encoding, @@ -145,7 +149,7 @@ async def _send_chat_request( self, chat_history: ChatHistory, settings: VertexAIChatPromptExecutionSettings ) -> list[ChatMessageContent]: """Send a chat request to the Vertex AI service.""" - vertexai.init(project=self.service_settings.project_id) + vertexai.init(project=self.service_settings.project_id, location=self.service_settings.region) model = GenerativeModel( self.service_settings.gemini_model_id, system_instruction=filter_system_message(chat_history), @@ -281,7 +285,7 @@ async def _send_chat_streaming_request( settings: VertexAIChatPromptExecutionSettings, ) -> AsyncGenerator[list[StreamingChatMessageContent], Any]: """Send a streaming chat request to the Vertex AI service.""" - vertexai.init(project=self.service_settings.project_id) + vertexai.init(project=self.service_settings.project_id, location=self.service_settings.region) model = GenerativeModel( self.service_settings.gemini_model_id, system_instruction=filter_system_message(chat_history), diff --git a/python/semantic_kernel/connectors/ai/google/vertex_ai/services/vertex_ai_text_completion.py b/python/semantic_kernel/connectors/ai/google/vertex_ai/services/vertex_ai_text_completion.py index 4cc9ba8da8a8..6919b6ba521e 100644 --- a/python/semantic_kernel/connectors/ai/google/vertex_ai/services/vertex_ai_text_completion.py +++ b/python/semantic_kernel/connectors/ai/google/vertex_ai/services/vertex_ai_text_completion.py @@ -32,6 +32,7 @@ class VertexAITextCompletion(VertexAIBase, TextCompletionClientBase): def __init__( self, project_id: str | None = None, + region: str | None = None, gemini_model_id: str | None = None, service_id: str | None = None, env_file_path: str | None = None, @@ -46,6 +47,7 @@ def 
__init__( Args: project_id (str): The Google Cloud project ID. + region (str): The Google Cloud region. gemini_model_id (str): The Gemini model ID. service_id (str): The Vertex AI service ID. env_file_path (str): The path to the environment file. @@ -54,6 +56,7 @@ def __init__( try: vertex_ai_settings = VertexAISettings.create( project_id=project_id, + region=region, gemini_model_id=gemini_model_id, env_file_path=env_file_path, env_file_encoding=env_file_encoding, @@ -83,7 +86,7 @@ async def get_text_contents( async def _send_request(self, prompt: str, settings: VertexAITextPromptExecutionSettings) -> list[TextContent]: """Send a text generation request to the Vertex AI service.""" - vertexai.init(project=self.service_settings.project_id) + vertexai.init(project=self.service_settings.project_id, location=self.service_settings.region) model = GenerativeModel(self.service_settings.gemini_model_id) response: GenerationResponse = await model.generate_content_async( @@ -134,7 +137,7 @@ async def _send_streaming_request( self, prompt: str, settings: VertexAITextPromptExecutionSettings ) -> AsyncGenerator[list[StreamingTextContent], Any]: """Send a text generation request to the Vertex AI service.""" - vertexai.init(project=self.service_settings.project_id) + vertexai.init(project=self.service_settings.project_id, location=self.service_settings.region) model = GenerativeModel(self.service_settings.gemini_model_id) response: AsyncIterable[GenerationResponse] = await model.generate_content_async( diff --git a/python/semantic_kernel/connectors/ai/google/vertex_ai/services/vertex_ai_text_embedding.py b/python/semantic_kernel/connectors/ai/google/vertex_ai/services/vertex_ai_text_embedding.py index 71d7c649b1bf..46e59ea9bfc4 100644 --- a/python/semantic_kernel/connectors/ai/google/vertex_ai/services/vertex_ai_text_embedding.py +++ b/python/semantic_kernel/connectors/ai/google/vertex_ai/services/vertex_ai_text_embedding.py @@ -29,6 +29,7 @@ class VertexAITextEmbedding(VertexAIBase, EmbeddingGeneratorBase): def __init__( self, project_id: str | None = None, + region: str | None = None, embedding_model_id: str | None = None, service_id: str | None = None, env_file_path: str | None = None, @@ -43,6 +44,7 @@ def __init__( Args: project_id (str): The Google Cloud project ID. + region (str): The Google Cloud region. embedding_model_id (str): The Gemini model ID. service_id (str): The Vertex AI service ID. env_file_path (str): The path to the environment file. 
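As an aside, a small usage sketch of the new `region` parameter, not part of the patch itself; the ids are placeholders and the re-export path is assumed:

```python
from semantic_kernel.connectors.ai.google.vertex_ai import VertexAITextEmbedding

# Pin the service to a specific Google Cloud region; the value is forwarded
# to vertexai.init(location=...) before any request is sent.
embedding_service = VertexAITextEmbedding(
    project_id="my-gcp-project",               # placeholder
    region="us-central1",                      # placeholder; may also come from VERTEX_AI_REGION
    embedding_model_id="text-embedding-004",   # placeholder
)
```

Because `region` is typed `str | None`, omitting it falls back to the SDK's default location, preserving the previous behavior.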
@@ -51,6 +53,7 @@ def __init__( try: vertex_ai_settings = VertexAISettings.create( project_id=project_id, + region=region, embedding_model_id=embedding_model_id, env_file_path=env_file_path, env_file_encoding=env_file_encoding, @@ -89,7 +92,7 @@ async def generate_raw_embeddings( settings = self.get_prompt_execution_settings_from_settings(settings) assert isinstance(settings, VertexAIEmbeddingPromptExecutionSettings) # nosec - vertexai.init(project=self.service_settings.project_id) + vertexai.init(project=self.service_settings.project_id, location=self.service_settings.region) model = TextEmbeddingModel.from_pretrained(self.service_settings.embedding_model_id) response: list[TextEmbedding] = await model.get_embeddings_async( texts, diff --git a/python/semantic_kernel/connectors/ai/google/vertex_ai/vertex_ai_settings.py b/python/semantic_kernel/connectors/ai/google/vertex_ai/vertex_ai_settings.py index 698d06f5ea67..66bc35035fd6 100644 --- a/python/semantic_kernel/connectors/ai/google/vertex_ai/vertex_ai_settings.py +++ b/python/semantic_kernel/connectors/ai/google/vertex_ai/vertex_ai_settings.py @@ -25,6 +25,8 @@ class VertexAISettings(KernelBaseSettings): (Env var VERTEX_AI_EMBEDDING_MODEL_ID) - project_id: str - The Google Cloud project ID. (Env var VERTEX_AI_PROJECT_ID) + - region: str - The Google Cloud region. + (Env var VERTEX_AI_REGION) """ env_prefix: ClassVar[str] = "VERTEX_AI_" @@ -32,3 +34,4 @@ class VertexAISettings(KernelBaseSettings): gemini_model_id: str | None = None embedding_model_id: str | None = None project_id: str + region: str | None = None From 23c939f4b4d5d9de9fc9410409713ee490d0e6ed Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Fri, 2 Aug 2024 10:54:30 -0700 Subject: [PATCH 10/13] Address comments --- ...azure_ai_inference_prompt_execution_settings.py | 13 +++++++++---- .../google_ai_prompt_execution_settings.py | 13 +++++++++---- .../services/google_ai_chat_completion.py | 4 ++-- .../services/vertex_ai_chat_completion.py | 4 ++-- .../vertex_ai_prompt_execution_settings.py | 13 +++++++++---- .../open_ai_prompt_execution_settings.py | 14 +++++++++----- 6 files changed, 40 insertions(+), 21 deletions(-) diff --git a/python/semantic_kernel/connectors/ai/azure_ai_inference/azure_ai_inference_prompt_execution_settings.py b/python/semantic_kernel/connectors/ai/azure_ai_inference/azure_ai_inference_prompt_execution_settings.py index 5b3c7bcf261f..9f0d8bba851d 100644 --- a/python/semantic_kernel/connectors/ai/azure_ai_inference/azure_ai_inference_prompt_execution_settings.py +++ b/python/semantic_kernel/connectors/ai/azure_ai_inference/azure_ai_inference_prompt_execution_settings.py @@ -30,10 +30,15 @@ class AzureAIInferencePromptExecutionSettings(PromptExecutionSettings): class AzureAIInferenceChatPromptExecutionSettings(AzureAIInferencePromptExecutionSettings): """Azure AI Inference Chat Prompt Execution Settings.""" - # Do not set the tools and tool_choice manually. - # They are set by the service based on the function choice configuration. - tools: list[dict[str, Any]] | None = Field(None, max_length=64) - tool_choice: str | None = None + tools: list[dict[str, Any]] | None = Field( + None, + max_length=64, + description="Do not set this manually. It is set by the service based on the function choice configuration.", + ) + tool_choice: str | None = Field( + None, + description="Do not set this manually. 
It is set by the service based on the function choice configuration.", + ) @experimental_class diff --git a/python/semantic_kernel/connectors/ai/google/google_ai/google_ai_prompt_execution_settings.py b/python/semantic_kernel/connectors/ai/google/google_ai/google_ai_prompt_execution_settings.py index 2cb4895fa20a..a1f0ce927e61 100644 --- a/python/semantic_kernel/connectors/ai/google/google_ai/google_ai_prompt_execution_settings.py +++ b/python/semantic_kernel/connectors/ai/google/google_ai/google_ai_prompt_execution_settings.py @@ -35,10 +35,15 @@ class GoogleAITextPromptExecutionSettings(GoogleAIPromptExecutionSettings): class GoogleAIChatPromptExecutionSettings(GoogleAIPromptExecutionSettings): """Google AI Chat Prompt Execution Settings.""" - # Do not set the tools and tool_config manually. - # They are set by the service based on the function choice configuration. - tools: list[dict[str, Any]] | None = Field(None, max_length=64) - tool_config: dict[str, Any] | None = None + tools: list[dict[str, Any]] | None = Field( + None, + max_length=64, + description="Do not set this manually. It is set by the service based on the function choice configuration.", + ) + tool_config: dict[str, Any] | None = Field( + None, + description="Do not set this manually. It is set by the service based on the function choice configuration.", + ) @override def prepare_settings_dict(self, **kwargs) -> dict[str, Any]: diff --git a/python/semantic_kernel/connectors/ai/google/google_ai/services/google_ai_chat_completion.py b/python/semantic_kernel/connectors/ai/google/google_ai/services/google_ai_chat_completion.py index c52d24f69cb7..e4129ab73d92 100644 --- a/python/semantic_kernel/connectors/ai/google/google_ai/services/google_ai_chat_completion.py +++ b/python/semantic_kernel/connectors/ai/google/google_ai/services/google_ai_chat_completion.py @@ -139,7 +139,7 @@ async def get_chat_message_contents( function_calls=function_calls, chat_history=chat_history, kernel=kernel, - arguments=kwargs.get("argument", None), + arguments=kwargs.get("arguments", None), function_call_count=fc_count, request_index=request_index, function_behavior=settings.function_choice_behavior, @@ -277,7 +277,7 @@ async def _get_streaming_chat_message_contents_auto_invoke( function_calls=function_calls, chat_history=chat_history, kernel=kernel, - arguments=kwargs.get("argument", None), + arguments=kwargs.get("arguments", None), function_call_count=len(function_calls), request_index=request_index, function_behavior=settings.function_choice_behavior, diff --git a/python/semantic_kernel/connectors/ai/google/vertex_ai/services/vertex_ai_chat_completion.py b/python/semantic_kernel/connectors/ai/google/vertex_ai/services/vertex_ai_chat_completion.py index 7ebe3770d15a..e69798b5e704 100644 --- a/python/semantic_kernel/connectors/ai/google/vertex_ai/services/vertex_ai_chat_completion.py +++ b/python/semantic_kernel/connectors/ai/google/vertex_ai/services/vertex_ai_chat_completion.py @@ -133,7 +133,7 @@ async def get_chat_message_contents( function_calls=function_calls, chat_history=chat_history, kernel=kernel, - arguments=kwargs.get("argument", None), + arguments=kwargs.get("arguments", None), function_call_count=fc_count, request_index=request_index, function_behavior=settings.function_choice_behavior, @@ -270,7 +270,7 @@ async def _get_streaming_chat_message_contents_auto_invoke( function_calls=function_calls, chat_history=chat_history, kernel=kernel, - arguments=kwargs.get("argument", None), + arguments=kwargs.get("arguments", None), 
function_call_count=len(function_calls), request_index=request_index, function_behavior=settings.function_choice_behavior, diff --git a/python/semantic_kernel/connectors/ai/google/vertex_ai/vertex_ai_prompt_execution_settings.py b/python/semantic_kernel/connectors/ai/google/vertex_ai/vertex_ai_prompt_execution_settings.py index 062568d21fbb..28c8eb6f28be 100644 --- a/python/semantic_kernel/connectors/ai/google/vertex_ai/vertex_ai_prompt_execution_settings.py +++ b/python/semantic_kernel/connectors/ai/google/vertex_ai/vertex_ai_prompt_execution_settings.py @@ -36,10 +36,15 @@ class VertexAITextPromptExecutionSettings(VertexAIPromptExecutionSettings): class VertexAIChatPromptExecutionSettings(VertexAIPromptExecutionSettings): """Vertex AI Chat Prompt Execution Settings.""" - # Do not set the tools and tool_config manually. - # They are set by the service based on the function choice configuration. - tools: list[Tool] | None = Field(None, max_length=64) - tool_config: ToolConfig | None = None + tools: list[Tool] | None = Field( + None, + max_length=64, + description="Do not set this manually. It is set by the service based on the function choice configuration.", + ) + tool_config: ToolConfig | None = Field( + None, + description="Do not set this manually. It is set by the service based on the function choice configuration.", + ) @override def prepare_settings_dict(self, **kwargs) -> dict[str, Any]: diff --git a/python/semantic_kernel/connectors/ai/open_ai/prompt_execution_settings/open_ai_prompt_execution_settings.py b/python/semantic_kernel/connectors/ai/open_ai/prompt_execution_settings/open_ai_prompt_execution_settings.py index 480d8f94342e..668f4ae650a4 100644 --- a/python/semantic_kernel/connectors/ai/open_ai/prompt_execution_settings/open_ai_prompt_execution_settings.py +++ b/python/semantic_kernel/connectors/ai/open_ai/prompt_execution_settings/open_ai_prompt_execution_settings.py @@ -66,11 +66,15 @@ class OpenAIChatPromptExecutionSettings(OpenAIPromptExecutionSettings): functions: list[dict[str, Any]] | None = None messages: list[dict[str, Any]] | None = None function_call_behavior: FunctionCallBehavior | None = Field(None, exclude=True) - - # Do not set the tools and tool_choice manually. - # They are set by the service based on the function choice configuration. - tools: list[dict[str, Any]] | None = Field(None, max_length=64) - tool_choice: str | None = None + tools: list[dict[str, Any]] | None = Field( + None, + max_length=64, + description="Do not set this manually. It is set by the service based on the function choice configuration.", + ) + tool_choice: str | None = Field( + None, + description="Do not set this manually. 
It is set by the service based on the function choice configuration.", + ) @field_validator("functions", "function_call", mode="after") @classmethod From 02edd9a9b8d0afefe3f9db8d65605fe4884a0b67 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Fri, 2 Aug 2024 13:28:21 -0700 Subject: [PATCH 11/13] Fix unit tests --- .../unit/contents/test_function_result_content.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/python/tests/unit/contents/test_function_result_content.py b/python/tests/unit/contents/test_function_result_content.py index a745c6c255ea..4b013d8a83dd 100644 --- a/python/tests/unit/contents/test_function_result_content.py +++ b/python/tests/unit/contents/test_function_result_content.py @@ -74,15 +74,11 @@ def test_from_fcc_and_result(result: Any): assert frc.metadata == {"test": "test", "test2": "test2"} -@pytest.mark.parametrize("unwrap", [True, False], ids=["unwrap", "no-unwrap"]) -def test_to_cmc(unwrap: bool): +def test_to_cmc(): frc = FunctionResultContent(id="test", name="test-function", result="test-result") - cmc = frc.to_chat_message_content(unwrap=unwrap) + cmc = frc.to_chat_message_content() assert cmc.role.value == "tool" - if unwrap: - assert cmc.items[0].text == "test-result" - else: - assert cmc.items[0].result == "test-result" + assert cmc.items[0].result == "test-result" def test_serialize(): From 84a5e9a55905e7919cee0b733f3ec7e05f519b01 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Fri, 2 Aug 2024 13:37:55 -0700 Subject: [PATCH 12/13] Fix more unit tests on env setup --- .../google_ai/services/test_google_ai_chat_completion.py | 6 ++++++ .../vertex_ai/services/test_vertex_ai_chat_completion.py | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/python/tests/unit/connectors/google/google_ai/services/test_google_ai_chat_completion.py b/python/tests/unit/connectors/google/google_ai/services/test_google_ai_chat_completion.py index 0cd19185ad94..7d73b882a00a 100644 --- a/python/tests/unit/connectors/google/google_ai/services/test_google_ai_chat_completion.py +++ b/python/tests/unit/connectors/google/google_ai/services/test_google_ai_chat_completion.py @@ -112,6 +112,7 @@ async def test_google_ai_chat_completion( @patch.object(GenerativeModel, "generate_content_async", new_callable=AsyncMock) async def test_google_ai_chat_completion_with_function_choice_behavior_fail_verification( chat_history: ChatHistory, + google_ai_unit_test_env, ) -> None: """Test completion of GoogleAIChatCompletion with function choice behavior expect verification failure""" @@ -133,6 +134,7 @@ async def test_google_ai_chat_completion_with_function_choice_behavior_fail_veri @patch.object(GenerativeModel, "generate_content_async", new_callable=AsyncMock) async def test_google_ai_chat_completion_with_function_choice_behavior( mock_google_ai_model_generate_content_async, + google_ai_unit_test_env, kernel, chat_history: ChatHistory, mock_google_ai_chat_completion_response_with_tool_call, @@ -167,6 +169,7 @@ async def test_google_ai_chat_completion_with_function_choice_behavior( @patch.object(GenerativeModel, "generate_content_async", new_callable=AsyncMock) async def test_google_ai_chat_completion_with_function_choice_behavior_no_tool_call( mock_google_ai_model_generate_content_async, + google_ai_unit_test_env, kernel, chat_history: ChatHistory, mock_google_ai_chat_completion_response, @@ -240,6 +243,7 @@ async def test_google_ai_streaming_chat_completion( @patch.object(GenerativeModel, "generate_content_async", new_callable=AsyncMock) async def 
test_google_ai_streaming_chat_completion_with_function_choice_behavior_fail_verification( chat_history: ChatHistory, + google_ai_unit_test_env, ) -> None: """Test streaming chat completion of GoogleAIChatCompletion with function choice behavior expect verification failure""" @@ -263,6 +267,7 @@ async def test_google_ai_streaming_chat_completion_with_function_choice_behavior @patch.object(GenerativeModel, "generate_content_async", new_callable=AsyncMock) async def test_google_ai_streaming_chat_completion_with_function_choice_behavior( mock_google_ai_model_generate_content_async, + google_ai_unit_test_env, kernel, chat_history: ChatHistory, mock_google_ai_streaming_chat_completion_response_with_tool_call, @@ -299,6 +304,7 @@ async def test_google_ai_streaming_chat_completion_with_function_choice_behavior @patch.object(GenerativeModel, "generate_content_async", new_callable=AsyncMock) async def test_google_ai_streaming_chat_completion_with_function_choice_behavior_no_tool_call( mock_google_ai_model_generate_content_async, + google_ai_unit_test_env, kernel, chat_history: ChatHistory, mock_google_ai_streaming_chat_completion_response, diff --git a/python/tests/unit/connectors/google/vertex_ai/services/test_vertex_ai_chat_completion.py b/python/tests/unit/connectors/google/vertex_ai/services/test_vertex_ai_chat_completion.py index 650bc64eb851..49d466875be1 100644 --- a/python/tests/unit/connectors/google/vertex_ai/services/test_vertex_ai_chat_completion.py +++ b/python/tests/unit/connectors/google/vertex_ai/services/test_vertex_ai_chat_completion.py @@ -111,6 +111,7 @@ async def test_vertex_ai_chat_completion( @pytest.mark.asyncio @patch.object(GenerativeModel, "generate_content_async", new_callable=AsyncMock) async def test_vertex_ai_chat_completion_with_function_choice_behavior_fail_verification( + vertex_ai_unit_test_env, chat_history: ChatHistory, ) -> None: """Test completion of VertexAIChatCompletion with function choice behavior expect verification failure""" @@ -133,6 +134,7 @@ async def test_vertex_ai_chat_completion_with_function_choice_behavior_fail_veri @patch.object(GenerativeModel, "generate_content_async", new_callable=AsyncMock) async def test_vertex_ai_chat_completion_with_function_choice_behavior( mock_vertex_ai_model_generate_content_async, + vertex_ai_unit_test_env, kernel, chat_history: ChatHistory, mock_vertex_ai_chat_completion_response_with_tool_call, @@ -167,6 +169,7 @@ async def test_vertex_ai_chat_completion_with_function_choice_behavior( @patch.object(GenerativeModel, "generate_content_async", new_callable=AsyncMock) async def test_vertex_ai_chat_completion_with_function_choice_behavior_no_tool_call( mock_vertex_ai_model_generate_content_async, + vertex_ai_unit_test_env, kernel, chat_history: ChatHistory, mock_vertex_ai_chat_completion_response, @@ -240,6 +243,7 @@ async def test_vertex_ai_streaming_chat_completion( @patch.object(GenerativeModel, "generate_content_async", new_callable=AsyncMock) async def test_vertex_ai_streaming_chat_completion_with_function_choice_behavior_fail_verification( chat_history: ChatHistory, + vertex_ai_unit_test_env, ) -> None: """Test streaming chat completion of VertexAIChatCompletion with function choice behavior expect verification failure""" @@ -263,6 +267,7 @@ async def test_vertex_ai_streaming_chat_completion_with_function_choice_behavior @patch.object(GenerativeModel, "generate_content_async", new_callable=AsyncMock) async def test_vertex_ai_streaming_chat_completion_with_function_choice_behavior( 
mock_vertex_ai_model_generate_content_async, + vertex_ai_unit_test_env, kernel, chat_history: ChatHistory, mock_vertex_ai_streaming_chat_completion_response_with_tool_call, @@ -299,6 +304,7 @@ async def test_vertex_ai_streaming_chat_completion_with_function_choice_behavior @patch.object(GenerativeModel, "generate_content_async", new_callable=AsyncMock) async def test_vertex_ai_streaming_chat_completion_with_function_choice_behavior_no_tool_call( mock_vertex_ai_model_generate_content_async, + vertex_ai_unit_test_env, kernel, chat_history: ChatHistory, mock_vertex_ai_streaming_chat_completion_response, From 71a35ec2f338c22ca4155e3e639ad467d3950dfe Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Fri, 2 Aug 2024 13:54:35 -0700 Subject: [PATCH 13/13] Remove unnecessary mocks --- .../google_ai/services/test_google_ai_chat_completion.py | 2 -- .../vertex_ai/services/test_vertex_ai_chat_completion.py | 4 +--- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/python/tests/unit/connectors/google/google_ai/services/test_google_ai_chat_completion.py b/python/tests/unit/connectors/google/google_ai/services/test_google_ai_chat_completion.py index 7d73b882a00a..bc24e98d43c0 100644 --- a/python/tests/unit/connectors/google/google_ai/services/test_google_ai_chat_completion.py +++ b/python/tests/unit/connectors/google/google_ai/services/test_google_ai_chat_completion.py @@ -109,7 +109,6 @@ async def test_google_ai_chat_completion( @pytest.mark.asyncio -@patch.object(GenerativeModel, "generate_content_async", new_callable=AsyncMock) async def test_google_ai_chat_completion_with_function_choice_behavior_fail_verification( chat_history: ChatHistory, google_ai_unit_test_env, @@ -240,7 +239,6 @@ async def test_google_ai_streaming_chat_completion( @pytest.mark.asyncio -@patch.object(GenerativeModel, "generate_content_async", new_callable=AsyncMock) async def test_google_ai_streaming_chat_completion_with_function_choice_behavior_fail_verification( chat_history: ChatHistory, google_ai_unit_test_env, diff --git a/python/tests/unit/connectors/google/vertex_ai/services/test_vertex_ai_chat_completion.py b/python/tests/unit/connectors/google/vertex_ai/services/test_vertex_ai_chat_completion.py index 49d466875be1..7bed2ae9e776 100644 --- a/python/tests/unit/connectors/google/vertex_ai/services/test_vertex_ai_chat_completion.py +++ b/python/tests/unit/connectors/google/vertex_ai/services/test_vertex_ai_chat_completion.py @@ -109,10 +109,9 @@ async def test_vertex_ai_chat_completion( @pytest.mark.asyncio -@patch.object(GenerativeModel, "generate_content_async", new_callable=AsyncMock) async def test_vertex_ai_chat_completion_with_function_choice_behavior_fail_verification( - vertex_ai_unit_test_env, chat_history: ChatHistory, + vertex_ai_unit_test_env, ) -> None: """Test completion of VertexAIChatCompletion with function choice behavior expect verification failure""" @@ -240,7 +239,6 @@ async def test_vertex_ai_streaming_chat_completion( @pytest.mark.asyncio -@patch.object(GenerativeModel, "generate_content_async", new_callable=AsyncMock) async def test_vertex_ai_streaming_chat_completion_with_function_choice_behavior_fail_verification( chat_history: ChatHistory, vertex_ai_unit_test_env,
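Taken together, the series wires automatic function calling into the Google connectors. Below is a minimal end-to-end sketch of the new code path; it is not taken from the patch, the plugin and model id are placeholders, and the import paths are assumed from the package layout:

```python
import asyncio

from semantic_kernel import Kernel
from semantic_kernel.connectors.ai.function_choice_behavior import FunctionChoiceBehavior
from semantic_kernel.connectors.ai.google.google_ai import (
    GoogleAIChatCompletion,
    GoogleAIChatPromptExecutionSettings,
)
from semantic_kernel.contents.chat_history import ChatHistory


async def main() -> None:
    kernel = Kernel()
    # Register any plugin whose functions the model may call, e.g.:
    # kernel.add_plugin(TimePlugin(), plugin_name="time")  # hypothetical plugin

    service = GoogleAIChatCompletion(gemini_model_id="gemini-1.5-flash")  # placeholder model id
    settings = GoogleAIChatPromptExecutionSettings(
        # Leave `tools`/`tool_config` unset; per the field descriptions added in
        # PATCH 10, the service populates them from this behavior.
        function_choice_behavior=FunctionChoiceBehavior.Auto(),
    )

    chat_history = ChatHistory()
    chat_history.add_user_message("What time is it in Seattle?")

    # `kernel` is required here: with auto-invoke enabled, the service raises
    # ServiceInvalidExecutionSettingsError when no kernel is passed.
    responses = await service.get_chat_message_contents(chat_history, settings, kernel=kernel)
    print(responses[0].content)


asyncio.run(main())
```

On each round trip the service appends the model's function calls and their results to `chat_history` and retries, up to `maximum_auto_invoke_attempts`, before returning the final completion.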