diff --git a/python/tests/unit/agents/bedrock_agent/test_bedrock_agent_channel.py b/python/tests/unit/agents/bedrock_agent/test_bedrock_agent_channel.py index 66e203d93065..a030a637393d 100644 --- a/python/tests/unit/agents/bedrock_agent/test_bedrock_agent_channel.py +++ b/python/tests/unit/agents/bedrock_agent/test_bedrock_agent_channel.py @@ -1,8 +1,28 @@ # Copyright (c) Microsoft. All rights reserved. +from collections.abc import AsyncIterable +from unittest.mock import MagicMock + import pytest +from semantic_kernel.agents.agent import Agent +from semantic_kernel.agents.bedrock.models.bedrock_agent_model import BedrockAgentModel +from semantic_kernel.agents.channels.bedrock_agent_channel import BedrockAgentChannel from semantic_kernel.contents.chat_message_content import ChatMessageContent +from semantic_kernel.contents.streaming_chat_message_content import StreamingChatMessageContent +from semantic_kernel.contents.utils.author_role import AuthorRole +from semantic_kernel.exceptions.agent_exceptions import AgentChatException + + +class ConcreteAgent(Agent): + async def get_response(self, *args, **kwargs) -> ChatMessageContent: + raise NotImplementedError + + def invoke(self, *args, **kwargs) -> AsyncIterable[ChatMessageContent]: + raise NotImplementedError + + def invoke_stream(self, *args, **kwargs) -> AsyncIterable[StreamingChatMessageContent]: + raise NotImplementedError @pytest.fixture @@ -32,6 +52,23 @@ def chat_history_not_alternate_role() -> list[ChatMessageContent]: ] +@pytest.fixture +def mock_agent(): + """ + Fixture that creates a mock BedrockAgent. + """ + from semantic_kernel.agents.bedrock.bedrock_agent import BedrockAgent + + # Create mocks + mock_agent = MagicMock(spec=BedrockAgent) + # Set the name and agent_model properties + mock_agent.name = "MockBedrockAgent" + mock_agent.agent_model = MagicMock(spec=BedrockAgentModel) + mock_agent.agent_model.foundation_model = "mock-foundation-model" + + return mock_agent + + async def test_receive_message(mock_channel, chat_history): # Test to verify the receive_message functionality await mock_channel.receive(chat_history) @@ -59,5 +96,221 @@ async def test_channel_reset(mock_channel, chat_history): # Test to verify the reset functionality await mock_channel.receive(chat_history) assert len(mock_channel) == len(chat_history) + assert len(mock_channel.messages) == len(chat_history) await mock_channel.reset() assert len(mock_channel) == 0 + assert len(mock_channel.messages) == 0 + + +async def test_receive_appends_history_correctly(mock_channel): + """Test that the receive method appends messages while ensuring they alternate in author role.""" + # Provide a list of messages with identical roles to see if placeholders are inserted + incoming_messages = [ + ChatMessageContent(role=AuthorRole.USER, content="User message 1"), + ChatMessageContent(role=AuthorRole.USER, content="User message 2"), + ChatMessageContent(role=AuthorRole.ASSISTANT, content="Assistant message 1"), + ChatMessageContent(role=AuthorRole.ASSISTANT, content="Assistant message 2"), + ] + + await mock_channel.receive(incoming_messages) + + # The final channel.messages should be: + # user message 1, user placeholder, user message 2, assistant placeholder, assistant message 1, + # assistant placeholder, assistant message 2 + expected_roles = [ + AuthorRole.USER, + AuthorRole.ASSISTANT, # placeholder + AuthorRole.USER, + AuthorRole.ASSISTANT, + AuthorRole.USER, # placeholder + AuthorRole.ASSISTANT, + ] + expected_contents = [ + "User message 1", + BedrockAgentChannel.MESSAGE_PLACEHOLDER, + "User message 2", + "Assistant message 1", + BedrockAgentChannel.MESSAGE_PLACEHOLDER, + "Assistant message 2", + ] + + assert len(mock_channel.messages) == len(expected_roles) + for i, (msg, exp_role, exp_content) in enumerate(zip(mock_channel.messages, expected_roles, expected_contents)): + assert msg.role == exp_role, f"Role mismatch at index {i}" + assert msg.content == exp_content, f"Content mismatch at index {i}" + + +async def test_invoke_raises_exception_for_non_bedrock_agent(mock_channel): + """Test invoke method raises AgentChatException if the agent provided is not a BedrockAgent.""" + # Place a message in the channel so it's not empty + mock_channel.messages.append(ChatMessageContent(role=AuthorRole.USER, content="User message")) + + # Create a dummy agent that is not BedrockAgent + non_bedrock_agent = ConcreteAgent() + + with pytest.raises(AgentChatException) as exc_info: + _ = [msg async for msg in mock_channel.invoke(non_bedrock_agent)] + + assert "Agent is not of the expected type" in str(exc_info.value) + + +async def test_invoke_raises_exception_if_no_history(mock_channel, mock_agent): + """Test invoke method raises AgentChatException if no chat history is available.""" + with pytest.raises(AgentChatException) as exc_info: + _ = [msg async for msg in mock_channel.invoke(mock_agent)] + + assert "No chat history available" in str(exc_info.value) + + +async def test_invoke_inserts_placeholders_when_history_needs_to_alternate(mock_channel, mock_agent): + """Test invoke ensures _ensure_history_alternates and _ensure_last_message_is_user are called.""" + # Put messages in the channel such that the last message is an assistant's + mock_channel.messages.append(ChatMessageContent(role=AuthorRole.ASSISTANT, content="Assistant 1")) + + # Mock agent.invoke to return an async generator + async def mock_invoke(session_id: str, input_text: str, sessionState=None, **kwargs): + # We just yield one message as if the agent responded + yield ChatMessageContent(role=AuthorRole.ASSISTANT, content="Mock Agent Response") + + mock_agent.invoke = mock_invoke + + # Because the last message is from the assistant, we expect a placeholder user message to be appended + # also the history might need to alternate. + # But since there's only one message, there's nothing to fix except the last message is user. + + # We will now add a user message so we do not get the "No chat history available" error + mock_channel.messages.append(ChatMessageContent(role=AuthorRole.USER, content="User 1")) + + # Now we do invoke + outputs = [msg async for msg in mock_channel.invoke(mock_agent)] + + # We'll check if the response is appended to channel.messages + assert len(outputs) == 1 + assert outputs[0][0] is True, "Expected a user-facing response" + agent_response = outputs[0][1] + assert agent_response.content == "Mock Agent Response" + + # The channel messages should now have 3 messages: the assistant, the user, and the new agent message + assert len(mock_channel.messages) == 3 + assert mock_channel.messages[-1].role == AuthorRole.ASSISTANT + assert mock_channel.messages[-1].content == "Mock Agent Response" + + +async def test_invoke_stream_raises_error_for_non_bedrock_agent(mock_channel): + """Test invoke_stream raises AgentChatException if the agent provided is not a BedrockAgent.""" + mock_channel.messages.append(ChatMessageContent(role=AuthorRole.USER, content="User message")) + + non_bedrock_agent = ConcreteAgent() + + with pytest.raises(AgentChatException) as exc_info: + _ = [msg async for msg in mock_channel.invoke_stream(non_bedrock_agent, [])] + + assert "Agent is not of the expected type" in str(exc_info.value) + + +async def test_invoke_stream_raises_no_chat_history(mock_channel, mock_agent): + """Test invoke_stream raises AgentChatException if no messages in the channel.""" + + with pytest.raises(AgentChatException) as exc_info: + _ = [msg async for msg in mock_channel.invoke_stream(mock_agent, [])] + + assert "No chat history available." in str(exc_info.value) + + +async def test_invoke_stream_appends_response_message(mock_channel, mock_agent): + """Test invoke_stream properly yields streaming content and appends an aggregated message at the end.""" + # Put a user message in the channel so it won't raise No chat history + mock_channel.messages.append(ChatMessageContent(role=AuthorRole.USER, content="Last user message")) + + async def mock_invoke_stream( + session_id: str, input_text: str, sessionState=None, **kwargs + ) -> AsyncIterable[StreamingChatMessageContent]: + yield StreamingChatMessageContent( + role=AuthorRole.ASSISTANT, + choice_index=0, + content="Hello", + ) + yield StreamingChatMessageContent( + role=AuthorRole.ASSISTANT, + choice_index=0, + content=" World", + ) + + mock_agent.invoke_stream = mock_invoke_stream + + # Check that we get the streamed messages and that the summarized message is appended afterward + messages_param = [ChatMessageContent(role=AuthorRole.USER, content="Last user message")] # just to pass the param + streamed_content = [msg async for msg in mock_channel.invoke_stream(mock_agent, messages_param)] + + # We expect two streamed chunks: "Hello" and " World" + assert len(streamed_content) == 2 + assert streamed_content[0].content == "Hello" + assert streamed_content[1].content == " World" + + # Then we expect the channel to append an aggregated ChatMessageContent with "Hello World" + assert len(messages_param) == 2 + appended = messages_param[1] + assert appended.role == AuthorRole.ASSISTANT + assert appended.content == "Hello World" + + +async def test_get_history(mock_channel, chat_history): + """Test get_history yields messages in reverse order.""" + mock_channel.messages = chat_history + + reversed_history = [msg async for msg in mock_channel.get_history()] + + # Should be reversed + assert reversed_history[0].content == "I'm good, thank you!" + assert reversed_history[1].content == "How are you?" + assert reversed_history[2].content == "Hello, User!" + assert reversed_history[3].content == "Hello, Bedrock!" + + +async def test_invoke_alternates_history_and_ensures_last_user_message(mock_channel, mock_agent): + """Test invoke method ensures history alternates and last message is user.""" + mock_channel.messages = [ + ChatMessageContent(role=AuthorRole.USER, content="User1"), + ChatMessageContent(role=AuthorRole.USER, content="User2"), + ChatMessageContent(role=AuthorRole.ASSISTANT, content="Assist1"), + ChatMessageContent(role=AuthorRole.ASSISTANT, content="Assist2"), + ChatMessageContent(role=AuthorRole.USER, content="User3"), + ChatMessageContent(role=AuthorRole.USER, content="User4"), + ChatMessageContent(role=AuthorRole.ASSISTANT, content="Assist3"), + ] + + async for _, msg in mock_channel.invoke(mock_agent): + pass + + # let's define expected roles from that final structure: + expected_roles = [ + AuthorRole.USER, + AuthorRole.ASSISTANT, # placeholder + AuthorRole.USER, + AuthorRole.ASSISTANT, + AuthorRole.USER, # placeholder + AuthorRole.ASSISTANT, + AuthorRole.USER, + AuthorRole.ASSISTANT, # placeholder + AuthorRole.USER, + AuthorRole.ASSISTANT, + AuthorRole.USER, # placeholder + ] + expected_contents = [ + "User1", + BedrockAgentChannel.MESSAGE_PLACEHOLDER, + "User2", + "Assist1", + BedrockAgentChannel.MESSAGE_PLACEHOLDER, + "Assist2", + "User3", + BedrockAgentChannel.MESSAGE_PLACEHOLDER, + "User4", + "Assist3", + BedrockAgentChannel.MESSAGE_PLACEHOLDER, + ] + + assert len(mock_channel.messages) == len(expected_roles) + for i, (msg, exp_role, exp_content) in enumerate(zip(mock_channel.messages, expected_roles, expected_contents)): + assert msg.role == exp_role, f"Role mismatch at index {i}. Got {msg.role}, expected {exp_role}" + assert msg.content == exp_content, f"Content mismatch at index {i}. Got {msg.content}, expected {exp_content}" diff --git a/python/tests/unit/connectors/ai/bedrock/services/test_bedrock_model_provider_utils.py b/python/tests/unit/connectors/ai/bedrock/services/test_bedrock_model_provider_utils.py index 4a5728be554c..ef279db2521d 100644 --- a/python/tests/unit/connectors/ai/bedrock/services/test_bedrock_model_provider_utils.py +++ b/python/tests/unit/connectors/ai/bedrock/services/test_bedrock_model_provider_utils.py @@ -1,5 +1,7 @@ # Copyright (c) Microsoft. All rights reserved. +from unittest.mock import MagicMock + import pytest from semantic_kernel.connectors.ai.bedrock.bedrock_prompt_execution_settings import BedrockChatPromptExecutionSettings @@ -7,10 +9,22 @@ BedrockModelProvider, ) from semantic_kernel.connectors.ai.bedrock.services.model_provider.utils import ( + MESSAGE_CONVERTERS, + finish_reason_from_bedrock_to_semantic_kernel, remove_none_recursively, update_settings_from_function_choice_configuration, ) +from semantic_kernel.connectors.ai.function_call_choice_configuration import FunctionCallChoiceConfiguration from semantic_kernel.connectors.ai.function_choice_behavior import FunctionChoiceBehavior +from semantic_kernel.connectors.ai.function_choice_type import FunctionChoiceType +from semantic_kernel.contents.chat_message_content import ChatMessageContent +from semantic_kernel.contents.function_call_content import FunctionCallContent +from semantic_kernel.contents.function_result_content import FunctionResultContent +from semantic_kernel.contents.image_content import ImageContent +from semantic_kernel.contents.text_content import TextContent +from semantic_kernel.contents.utils.author_role import AuthorRole +from semantic_kernel.contents.utils.finish_reason import FinishReason +from semantic_kernel.exceptions.service_exceptions import ServiceInvalidRequestError from semantic_kernel.kernel import Kernel @@ -146,3 +160,284 @@ def test_inference_profile_with_bedrock_model() -> None: unknown_inference_profile = "unknown" with pytest.raises(ValueError, match="Model ID unknown does not contain a valid model provider name."): BedrockModelProvider.to_model_provider(unknown_inference_profile) + + +def test_remove_none_recursively_empty_dict() -> None: + """Test that an empty dict returns an empty dict.""" + assert remove_none_recursively({}) == {} + + +def test_remove_none_recursively_no_none() -> None: + """Test that a dict with no None values remains the same.""" + original = {"a": 1, "b": 2} + result = remove_none_recursively(original) + assert result == {"a": 1, "b": 2} + + +def test_remove_none_recursively_with_none() -> None: + """Test that dict values of None are removed.""" + original = {"a": 1, "b": None, "c": {"d": None, "e": 3}} + result = remove_none_recursively(original) + # 'b' should be removed and 'd' inside nested dict should be removed + assert result == {"a": 1, "c": {"e": 3}} + + +def test_remove_none_recursively_max_depth() -> None: + """Test that the function respects max_depth.""" + original = {"a": {"b": {"c": None}}} + # If max_depth=1, it won't go deep enough to remove 'c'. + result = remove_none_recursively(original, max_depth=1) + assert result == {"a": {"b": {"c": None}}} + + # If max_depth=3, it should remove 'c'. + result = remove_none_recursively(original, max_depth=3) + assert result == {"a": {"b": {}}} + + +def test_format_system_message() -> None: + """Test that system message is formatted correctly.""" + content = ChatMessageContent(role=AuthorRole.SYSTEM, content="System message") + formatted = MESSAGE_CONVERTERS[AuthorRole.SYSTEM](content) + assert formatted == {"text": "System message"} + + +def test_format_user_message_text_only() -> None: + """Test user message with only text content.""" + text_item = TextContent(text="Hello!") + user_message = ChatMessageContent(role=AuthorRole.USER, items=[text_item]) + + formatted = MESSAGE_CONVERTERS[AuthorRole.USER](user_message) + assert formatted["role"] == "user" + assert len(formatted["content"]) == 1 + assert formatted["content"][0] == {"text": "Hello!"} + + +def test_format_user_message_image_only() -> None: + """Test user message with only image content.""" + img_item = ImageContent(data=b"abc", mime_type="image/png") + user_message = ChatMessageContent(role=AuthorRole.USER, items=[img_item]) + + formatted = MESSAGE_CONVERTERS[AuthorRole.USER](user_message) + assert formatted["role"] == "user" + assert len(formatted["content"]) == 1 + image_section = formatted["content"][0].get("image") + assert image_section["format"] == "png" + assert image_section["source"]["bytes"] == b"abc" + + +def test_format_user_message_unsupported_content() -> None: + """Test user message raises error with unsupported content type.""" + # We can simulate an unsupported content type by using FunctionCallContent. + func_call_item = FunctionCallContent(id="123", function_name="test_function", arguments="{}") + user_message = ChatMessageContent(role=AuthorRole.USER, items=[func_call_item]) + + with pytest.raises(ServiceInvalidRequestError) as exc: + MESSAGE_CONVERTERS[AuthorRole.USER](user_message) + + assert "Only text and image content are supported in a user message." in str(exc.value) + + +def test_format_assistant_message_text_content() -> None: + """Test assistant message with text content.""" + text_item = TextContent(text="Assistant response") + assistant_message = ChatMessageContent(role=AuthorRole.ASSISTANT, items=[text_item]) + + formatted = MESSAGE_CONVERTERS[AuthorRole.ASSISTANT](assistant_message) + assert formatted["role"] == "assistant" + assert formatted["content"] == [{"text": "Assistant response"}] + + +def test_format_assistant_message_function_call_content() -> None: + """Test assistant message with function call content.""" + func_item = FunctionCallContent( + id="fc1", plugin_name="plugin", function_name="function", arguments='{"param": "value"}' + ) + assistant_message = ChatMessageContent(role=AuthorRole.ASSISTANT, items=[func_item]) + + formatted = MESSAGE_CONVERTERS[AuthorRole.ASSISTANT](assistant_message) + assert len(formatted["content"]) == 1 + tool_use = formatted["content"][0].get("toolUse") + assert tool_use + assert tool_use["toolUseId"] == "fc1" + assert tool_use["name"] == "plugin-function" + assert tool_use["input"] == {"param": "value"} + + +def test_format_assistant_message_image_content_raises() -> None: + """Test assistant message with image raises error.""" + img_item = ImageContent(data=b"abc", mime_type="image/jpeg") + assistant_message = ChatMessageContent(role=AuthorRole.ASSISTANT, items=[img_item]) + + with pytest.raises(ServiceInvalidRequestError) as exc: + MESSAGE_CONVERTERS[AuthorRole.ASSISTANT](assistant_message) + + assert "Image content is not supported in an assistant message." in str(exc.value) + + +def test_format_assistant_message_unsupported_type() -> None: + """Test assistant message with unsupported item content type.""" + func_res_item = FunctionResultContent(id="res1", function_name="some_function", result="some_result") + assistant_message = ChatMessageContent(role=AuthorRole.ASSISTANT, items=[func_res_item]) + + with pytest.raises(ServiceInvalidRequestError) as exc: + MESSAGE_CONVERTERS[AuthorRole.ASSISTANT](assistant_message) + assert "Unsupported content type in an assistant message:" in str(exc.value) + + +def test_format_tool_message_text() -> None: + """Test tool message with text content.""" + text_item = TextContent(text="Some text") + tool_message = ChatMessageContent(role=AuthorRole.TOOL, items=[text_item]) + + formatted = MESSAGE_CONVERTERS[AuthorRole.TOOL](tool_message) + assert formatted["role"] == "user" # note that for a tool message, role set to 'user' + assert formatted["content"] == [{"text": "Some text"}] + + +def test_format_tool_message_function_result() -> None: + """Test tool message with function result content.""" + func_result_item = FunctionResultContent(id="res_id", function_name="test_function", result="some result") + tool_message = ChatMessageContent(role=AuthorRole.TOOL, items=[func_result_item]) + + formatted = MESSAGE_CONVERTERS[AuthorRole.TOOL](tool_message) + assert formatted["role"] == "user" + content = formatted["content"][0] + assert content.get("toolResult") + assert content["toolResult"]["toolUseId"] == "res_id" + assert content["toolResult"]["content"] == [{"text": "some result"}] + + +def test_format_tool_message_image_raises() -> None: + """Test tool message with image content raises an error.""" + img_item = ImageContent(data=b"xyz", mime_type="image/jpeg") + tool_message = ChatMessageContent(role=AuthorRole.TOOL, items=[img_item]) + + with pytest.raises(ServiceInvalidRequestError) as exc: + MESSAGE_CONVERTERS[AuthorRole.TOOL](tool_message) + assert "Image content is not supported in a tool message." in str(exc.value) + + +def test_finish_reason_from_bedrock_to_semantic_kernel_stop() -> None: + """Test that 'stop_sequence' maps to FinishReason.STOP""" + reason = finish_reason_from_bedrock_to_semantic_kernel("stop_sequence") + assert reason == FinishReason.STOP + + reason = finish_reason_from_bedrock_to_semantic_kernel("end_turn") + assert reason == FinishReason.STOP + + +def test_finish_reason_from_bedrock_to_semantic_kernel_length() -> None: + """Test that 'max_tokens' maps to FinishReason.LENGTH""" + reason = finish_reason_from_bedrock_to_semantic_kernel("max_tokens") + assert reason == FinishReason.LENGTH + + +def test_finish_reason_from_bedrock_to_semantic_kernel_content_filtered() -> None: + """Test that 'content_filtered' maps to FinishReason.CONTENT_FILTER""" + reason = finish_reason_from_bedrock_to_semantic_kernel("content_filtered") + assert reason == FinishReason.CONTENT_FILTER + + +def test_finish_reason_from_bedrock_to_semantic_kernel_tool_use() -> None: + """Test that 'tool_use' maps to FinishReason.TOOL_CALLS""" + reason = finish_reason_from_bedrock_to_semantic_kernel("tool_use") + assert reason == FinishReason.TOOL_CALLS + + +def test_finish_reason_from_bedrock_to_semantic_kernel_unknown() -> None: + """Test that unknown finish reason returns None""" + reason = finish_reason_from_bedrock_to_semantic_kernel("something_unknown") + assert reason is None + + +@pytest.fixture +def mock_bedrock_settings() -> BedrockChatPromptExecutionSettings: + """Helper fixture for BedrockChatPromptExecutionSettings.""" + return BedrockChatPromptExecutionSettings() + + +@pytest.fixture +def mock_function_choice_config() -> FunctionCallChoiceConfiguration: + """Helper fixture for a sample FunctionCallChoiceConfiguration.""" + + # We'll create mock kernel functions with metadata + mock_func_1 = MagicMock() + mock_func_1.fully_qualified_name = "plugin-function1" + mock_func_1.description = "Function 1 description" + + param1 = MagicMock() + param1.name = "param1" + param1.schema_data = {"type": "string"} + param1.is_required = True + + param2 = MagicMock() + param2.name = "param2" + param2.schema_data = {"type": "integer"} + param2.is_required = False + + mock_func_1.parameters = [ + param1, + param2, + ] + mock_func_2 = MagicMock() + mock_func_2.fully_qualified_name = "plugin-function2" + mock_func_2.description = "Function 2 description" + mock_func_2.parameters = [] + + config = FunctionCallChoiceConfiguration() + config.available_functions = [mock_func_1, mock_func_2] + + return config + + +def test_update_settings_from_function_choice_configuration_none_type( + mock_function_choice_config, mock_bedrock_settings +) -> None: + """Test that if the FunctionChoiceType is NONE it doesn't modify settings.""" + update_settings_from_function_choice_configuration( + mock_function_choice_config, mock_bedrock_settings, FunctionChoiceType.NONE + ) + assert mock_bedrock_settings.tool_choice is None + assert mock_bedrock_settings.tools is None + + +def test_update_settings_from_function_choice_configuration_auto_two_tools( + mock_function_choice_config, mock_bedrock_settings +) -> None: + """Test that AUTO sets tool_choice to {"auto": {}} and sets tools list""" + update_settings_from_function_choice_configuration( + mock_function_choice_config, mock_bedrock_settings, FunctionChoiceType.AUTO + ) + assert mock_bedrock_settings.tool_choice == {"auto": {}} + assert len(mock_bedrock_settings.tools) == 2 + # Validate structure of first tool + tool_spec_1 = mock_bedrock_settings.tools[0].get("toolSpec") + assert tool_spec_1["name"] == "plugin-function1" + assert tool_spec_1["description"] == "Function 1 description" + + +def test_update_settings_from_function_choice_configuration_required_many( + mock_function_choice_config, mock_bedrock_settings +) -> None: + """Test that REQUIRED with more than one function sets tool_choice to {"any": {}}.""" + update_settings_from_function_choice_configuration( + mock_function_choice_config, mock_bedrock_settings, FunctionChoiceType.REQUIRED + ) + assert mock_bedrock_settings.tool_choice == {"any": {}} + assert len(mock_bedrock_settings.tools) == 2 + + +def test_update_settings_from_function_choice_configuration_required_one(mock_bedrock_settings) -> None: + """Test that REQUIRED with a single function picks "tool" with that function name.""" + single_func = MagicMock() + single_func.fully_qualified_name = "plugin-function" + single_func.description = "Only function" + single_func.parameters = [] + + config = FunctionCallChoiceConfiguration() + config.available_functions = [single_func] + + update_settings_from_function_choice_configuration(config, mock_bedrock_settings, FunctionChoiceType.REQUIRED) + assert mock_bedrock_settings.tool_choice == {"tool": {"name": "plugin-function"}} + assert len(mock_bedrock_settings.tools) == 1 + assert mock_bedrock_settings.tools[0]["toolSpec"]["name"] == "plugin-function" diff --git a/python/tests/unit/connectors/ai/mistral_ai/services/test_mistralai_chat_completion.py b/python/tests/unit/connectors/ai/mistral_ai/services/test_mistralai_chat_completion.py index 8db675b627bc..b8b8c25bd951 100644 --- a/python/tests/unit/connectors/ai/mistral_ai/services/test_mistralai_chat_completion.py +++ b/python/tests/unit/connectors/ai/mistral_ai/services/test_mistralai_chat_completion.py @@ -1,11 +1,22 @@ # Copyright (c) Microsoft. All rights reserved. + from unittest.mock import AsyncMock, MagicMock, patch import pytest -from mistralai import Mistral +from mistralai import CompletionEvent, Mistral +from mistralai.models import ( + AssistantMessage, + ChatCompletionChoice, + ChatCompletionResponse, + CompletionChunk, + CompletionResponseStreamChoice, + DeltaMessage, + UsageInfo, +) from semantic_kernel.connectors.ai.chat_completion_client_base import ChatCompletionClientBase -from semantic_kernel.connectors.ai.function_choice_behavior import FunctionChoiceBehavior +from semantic_kernel.connectors.ai.function_call_choice_configuration import FunctionCallChoiceConfiguration +from semantic_kernel.connectors.ai.function_choice_behavior import FunctionChoiceBehavior, FunctionChoiceType from semantic_kernel.connectors.ai.mistral_ai.prompt_execution_settings.mistral_ai_prompt_execution_settings import ( MistralAIChatPromptExecutionSettings, ) @@ -13,11 +24,10 @@ from semantic_kernel.connectors.ai.open_ai.prompt_execution_settings.open_ai_prompt_execution_settings import ( OpenAIChatPromptExecutionSettings, ) +from semantic_kernel.contents import FunctionCallContent, TextContent from semantic_kernel.contents.chat_history import ChatHistory from semantic_kernel.contents.chat_message_content import ( ChatMessageContent, - FunctionCallContent, - TextContent, ) from semantic_kernel.contents.streaming_chat_message_content import StreamingChatMessageContent from semantic_kernel.contents.utils.author_role import AuthorRole @@ -27,6 +37,7 @@ ServiceResponseException, ) from semantic_kernel.functions.kernel_arguments import KernelArguments +from semantic_kernel.functions.kernel_function import KernelFunction from semantic_kernel.functions.kernel_function_decorator import kernel_function from semantic_kernel.kernel import Kernel @@ -372,3 +383,214 @@ async def test_with_different_execution_settings_stream( continue assert mock_mistral_ai_client_completion_stream.chat.stream_async.call_args.kwargs["temperature"] == 0.2 assert mock_mistral_ai_client_completion_stream.chat.stream_async.call_args.kwargs["seed"] == 2 + + +async def test_mistral_ai_chat_completion_get_chat_message_contents_success(): + """Test get_chat_message_contents with a successful ChatCompletionResponse.""" + + # Mock the response from the Mistral chat complete_async. + mock_response = ChatCompletionResponse( + id="some_id", + object="object", + created=12345, + usage=UsageInfo(prompt_tokens=10, completion_tokens=20, total_tokens=30), + model="test-model", + choices=[ + ChatCompletionChoice( + index=0, + message=AssistantMessage(role="assistant", content="Hello!"), + finish_reason="stop", + ) + ], + ) + + async_mock_client = MagicMock(spec=Mistral) + async_mock_client.chat = MagicMock() + async_mock_client.chat.complete_async = AsyncMock(return_value=mock_response) + + chat_completion = MistralAIChatCompletion( + ai_model_id="test-model", + api_key="test_key", + async_client=async_mock_client, + ) + + # We create a ChatHistory. + chat_history = ChatHistory() + settings = MistralAIChatPromptExecutionSettings() + + results = await chat_completion.get_chat_message_contents(chat_history, settings) + + # We should have exactly one ChatMessageContent. + assert len(results) == 1 + assert results[0].role.value == "assistant" + assert results[0].finish_reason is not None + assert results[0].finish_reason.value == "stop" + assert "Hello!" in results[0].content + async_mock_client.chat.complete_async.assert_awaited_once() + + +async def test_mistral_ai_chat_completion_get_chat_message_contents_failure(): + """Test get_chat_message_contents should raise ServiceResponseException if Mistral call fails.""" + async_mock_client = MagicMock(spec=Mistral) + async_mock_client.chat = MagicMock() + async_mock_client.chat.complete_async = AsyncMock(side_effect=Exception("API error")) + + chat_completion = MistralAIChatCompletion( + ai_model_id="test-model", + api_key="test_key", + async_client=async_mock_client, + ) + + chat_history = ChatHistory() + settings = MistralAIChatPromptExecutionSettings() + + with pytest.raises(ServiceResponseException) as exc: + await chat_completion.get_chat_message_contents(chat_history, settings) + assert "service failed to complete the prompt" in str(exc.value) + + +async def test_mistral_ai_chat_completion_get_streaming_chat_message_contents_success(): + """Test get_streaming_chat_message_contents when streaming successfully.""" + + # We'll yield multiple chunks to simulate streaming. + mock_chunk1 = CompletionEvent( + data=CompletionChunk( + id="chunk1", + created=1, + model="test-model", + choices=[ + CompletionResponseStreamChoice( + index=0, + delta=DeltaMessage(role="assistant", content="Hello "), + finish_reason=None, + ) + ], + ) + ) + mock_chunk2 = CompletionEvent( + data=CompletionChunk( + id="chunk1", + created=1, + model="test-model", + choices=[ + CompletionResponseStreamChoice( + index=0, + delta=DeltaMessage(content="World!"), + finish_reason="stop", + ) + ], + ) + ) + + async def mock_stream_async(**kwargs): + yield mock_chunk1 + yield mock_chunk2 + + async_mock_client = MagicMock(spec=Mistral) + async_mock_client.chat = MagicMock() + async_mock_client.chat.stream_async = AsyncMock(return_value=mock_stream_async()) + + chat_completion = MistralAIChatCompletion( + ai_model_id="test-model", + api_key="test_key", + async_client=async_mock_client, + ) + + chat_history = ChatHistory() + settings = MistralAIChatPromptExecutionSettings() + + collected_chunks = [] + async for chunk_list in chat_completion.get_streaming_chat_message_contents(chat_history, settings): + collected_chunks.append(chunk_list) + + # We expect two sets of chunk_list yields. + assert len(collected_chunks) == 2 + assert len(collected_chunks[0]) == 1 + assert len(collected_chunks[1]) == 1 + + # First chunk contains "Hello ", second chunk "World!". + assert collected_chunks[0][0].items[0].text == "Hello " + assert collected_chunks[1][0].items[0].text == "World!" + + +async def test_mistral_ai_chat_completion_get_streaming_chat_message_contents_failure(): + """Test get_streaming_chat_message_contents raising a ServiceResponseException on failure.""" + async_mock_client = MagicMock(spec=Mistral) + async_mock_client.chat = MagicMock() + async_mock_client.chat.stream_async = AsyncMock(side_effect=Exception("Streaming error")) + + chat_completion = MistralAIChatCompletion( + ai_model_id="test-model", + api_key="test_key", + async_client=async_mock_client, + ) + + chat_history = ChatHistory() + settings = MistralAIChatPromptExecutionSettings() + + with pytest.raises(ServiceResponseException) as exc: + async for _ in chat_completion.get_streaming_chat_message_contents(chat_history, settings): + pass + assert "service failed to complete the prompt" in str(exc.value) + + +async def test_mistral_ai_chat_completion_update_settings_from_function_call_configuration_mistral(): + """Test update_settings_from_function_call_configuration_mistral sets tools etc.""" + + chat_completion = MistralAIChatCompletion( + ai_model_id="test-model", + api_key="test_key", + ) + + # Create a mock settings object. + settings = MistralAIChatPromptExecutionSettings() + # Create a function choice config with some available functions. + config = FunctionCallChoiceConfiguration() + mock_func = MagicMock( + spec=KernelFunction, + ) + mock_func.name = "my_func" + mock_func.description = "some desc" + mock_func.fully_qualified_name = "mod.my_func" + mock_func.parameters = [] + config.available_functions = [mock_func] + + # Call the update_settings_from_function_call_configuration_mistral with type=ANY. + chat_completion.update_settings_from_function_call_configuration_mistral( + function_choice_configuration=config, + settings=settings, + type=FunctionChoiceType.AUTO, + ) + + assert settings.tool_choice == FunctionChoiceType.AUTO.value + assert settings.tools is not None + assert len(settings.tools) == 1 + assert settings.tools[0]["function"]["name"] == "mod.my_func" + + +async def test_mistral_ai_chat_completion_reset_function_choice_settings(): + """Test that _reset_function_choice_settings resets specific attributes.""" + chat_completion = MistralAIChatCompletion( + ai_model_id="test-model", + api_key="test_key", + ) + settings = MistralAIChatPromptExecutionSettings(tool_choice="any", tools=[{"name": "func1"}]) + + chat_completion._reset_function_choice_settings(settings) + assert settings.tool_choice is None + assert settings.tools is None + + +async def test_mistral_ai_chat_completion_service_url(): + """Test that service_url attempts to use _endpoint from the async_client.""" + async_mock_client = MagicMock(spec=Mistral) + async_mock_client._endpoint = "mistral" + + chat_completion = MistralAIChatCompletion( + ai_model_id="test-model", + api_key="test_key", + async_client=async_mock_client, + ) + + url = chat_completion.service_url() + assert url == "mistral" diff --git a/python/tests/unit/connectors/search/bing/test_bing_search.py b/python/tests/unit/connectors/search/bing/test_bing_search.py new file mode 100644 index 000000000000..c7a74c4a04d5 --- /dev/null +++ b/python/tests/unit/connectors/search/bing/test_bing_search.py @@ -0,0 +1,286 @@ +# Copyright (c) Microsoft. All rights reserved. + +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest + +from semantic_kernel.connectors.search.bing.bing_search import BingSearch +from semantic_kernel.connectors.search.bing.bing_search_response import BingSearchResponse, BingWebPages +from semantic_kernel.connectors.search.bing.bing_web_page import BingWebPage +from semantic_kernel.data.kernel_search_results import KernelSearchResults +from semantic_kernel.data.text_search.text_search_filter import TextSearchFilter +from semantic_kernel.data.text_search.text_search_options import TextSearchOptions +from semantic_kernel.data.text_search.text_search_result import TextSearchResult +from semantic_kernel.exceptions import ServiceInitializationError, ServiceInvalidRequestError + + +@pytest.fixture +def bing_search(bing_unit_test_env): + """Set up the fixture to configure the Bing Search for these tests.""" + return BingSearch() + + +@pytest.fixture +def async_client_mock(): + """Set up the fixture to mock AsyncClient.""" + async_client_mock = AsyncMock() + with patch( + "semantic_kernel.connectors.search.bing.bing_search.AsyncClient.__aenter__", return_value=async_client_mock + ): + yield async_client_mock + + +@pytest.fixture +def mock_bing_search_response(): + """Set up the fixture to mock BingSearchResponse.""" + mock_web_page = BingWebPage(name="Page Name", snippet="Page Snippet", url="test") + mock_response = BingSearchResponse( + query_context={}, + webPages=MagicMock(spec=BingWebPages, value=[mock_web_page], total_estimated_matches=3), + ) + + with ( + patch.object(BingSearchResponse, "model_validate_json", return_value=mock_response), + ): + yield mock_response + + +async def test_bing_search_init_success(bing_search): + """Test that BingSearch initializes successfully with valid env.""" + # Should not raise any exception + assert bing_search.settings.api_key.get_secret_value() == "test_api_key" + assert bing_search.settings.custom_config == "test_org_id" + + +async def test_bing_search_init_validation_error(): + """Test that BingSearch raises ServiceInitializationError if BingSettings creation fails.""" + # Act / Assert + with pytest.raises(ServiceInitializationError) as exc_info: + _ = BingSearch(env_file_path="invalid.env") + assert "Failed to create Bing settings." in str(exc_info.value) + + +async def test_search_success(bing_unit_test_env, async_client_mock): + """Test that search method returns KernelSearchResults successfully on valid response.""" + # Arrange + mock_web_pages = BingWebPage(snippet="Test snippet") + mock_response = BingSearchResponse( + webPages=MagicMock(spec=BingWebPages, value=[mock_web_pages], total_estimated_matches=10), + query_context={"alteredQuery": "altered something"}, + ) + + mock_result = MagicMock() + mock_result.text = """ +{"webPages": { + "value": [{"snippet": "Test snippet"}], + "totalEstimatedMatches": 10}, + "queryContext": {"alteredQuery": "altered something"} +}""" + async_client_mock.get.return_value = mock_result + + # Act + with ( + patch.object(BingSearchResponse, "model_validate_json", return_value=mock_response), + ): + search_instance = BingSearch() + options = TextSearchOptions(include_total_count=True) + kernel_results: KernelSearchResults[str] = await search_instance.search("Test query", options) + + # Assert + results_list = [] + async for res in kernel_results.results: + results_list.append(res) + + assert len(results_list) == 1 + assert results_list[0] == "Test snippet" + assert kernel_results.total_count == 10 + assert kernel_results.metadata == {"altered_query": "altered something"} + + +async def test_search_http_status_error(bing_unit_test_env, async_client_mock): + """Test that search method raises ServiceInvalidRequestError on HTTPStatusError.""" + # Arrange + mock_response = MagicMock() + mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( + "Error", request=MagicMock(), response=MagicMock() + ) + async_client_mock.get.return_value = mock_response + + # Act + search_instance = BingSearch() + + # Assert + with pytest.raises(ServiceInvalidRequestError) as exc_info: + await search_instance.search("Test query") + assert "Failed to get search results." in str(exc_info.value) + + +async def test_search_request_error(bing_unit_test_env, async_client_mock): + """Test that search method raises ServiceInvalidRequestError on RequestError.""" + # Arrange + async_client_mock.get.side_effect = httpx.RequestError("Client error") + + # Act + search_instance = BingSearch() + + # Assert + with pytest.raises(ServiceInvalidRequestError) as exc_info: + await search_instance.search("Test query") + assert "A client error occurred while getting search results." in str(exc_info.value) + + +async def test_search_generic_exception(bing_unit_test_env, async_client_mock): + """Test that search method raises ServiceInvalidRequestError on unexpected exception.""" + # Arrange + async_client_mock.get.side_effect = Exception("Something unexpected") + + search_instance = BingSearch() + # Assert + with pytest.raises(ServiceInvalidRequestError) as exc_info: + await search_instance.search("Test query") + assert "An unexpected error occurred while getting search results." in str(exc_info.value) + + +async def test_validate_options_raises_error_for_large_top(bing_search): + """Test that _validate_options raises ServiceInvalidRequestError when top >= 50.""" + # Arrange + options = TextSearchOptions(top=50) + + # Act / Assert + with pytest.raises(ServiceInvalidRequestError) as exc_info: + await bing_search.search("test", options) + assert "count value must be less than 50." in str(exc_info.value) + + +async def test_get_text_search_results_success(bing_unit_test_env, async_client_mock): + """Test that get_text_search_results returns KernelSearchResults[TextSearchResult].""" + # Arrange + mock_web_pages = BingWebPage(name="Result Name", snippet="Snippet", url="test") + mock_response = BingSearchResponse( + query_context={}, + webPages=MagicMock(spec=BingWebPages, value=[mock_web_pages], total_estimated_matches=5), + ) + + mock_result = MagicMock() + mock_result.text = """" +{"webPages": { + "value": [{"snippet": "Snippet", "name":"Result Name", "url":"test"}], + "totalEstimatedMatches": 5}, + "queryContext": {} +}' +""" + async_client_mock.get.return_value = mock_result + + # Act + with ( + patch.object(BingSearchResponse, "model_validate_json", return_value=mock_response), + ): + search_instance = BingSearch() + options = TextSearchOptions(include_total_count=True) + kernel_results: KernelSearchResults[TextSearchResult] = await search_instance.get_text_search_results( + "Test query", options + ) + + # Assert + results_list = [] + async for res in kernel_results.results: + results_list.append(res) + + assert len(results_list) == 1 + assert isinstance(results_list[0], TextSearchResult) + assert results_list[0].name == "Result Name" + assert results_list[0].value == "Snippet" + assert results_list[0].link == "test" + assert kernel_results.total_count == 5 + + +async def test_get_search_results_success(bing_unit_test_env, async_client_mock, mock_bing_search_response): + """Test that get_search_results returns KernelSearchResults[BingWebPage].""" + # Arrange + mock_result = MagicMock() + mock_result.text = """ +{"webPages": { + "value": [{"name": "Page Name", "snippet": "Page Snippet", "url": "test"}], + "totalEstimatedMatches": 3}, + "queryContext": {} +}""" + async_client_mock.get.return_value = mock_result + + # Act + search_instance = BingSearch() + options = TextSearchOptions(include_total_count=True) + kernel_results = await search_instance.get_search_results("Another query", options) + + # Assert + results_list = [] + async for res in kernel_results.results: + results_list.append(res) + + assert len(results_list) == 1 + assert isinstance(results_list[0], BingWebPage) + assert results_list[0].name == "Page Name" + assert results_list[0].snippet == "Page Snippet" + assert results_list[0].url == "test" + assert kernel_results.total_count == 3 + + +async def test_search_no_filter(bing_search, async_client_mock, mock_bing_search_response): + """Test that search properly sets params when no filter is provided.""" + # Arrange + options = TextSearchOptions() + + # Act + await bing_search.search("test query", options) + + # Assert + params = async_client_mock.get.call_args.kwargs["params"] + + assert params["count"] == options.top + assert params["offset"] == options.skip + + # TODO check: shouldn't this output be "test query" instead of "test query+"? + assert params["q"] == "test query+" + + +async def test_search_equal_to_filter(bing_search, async_client_mock, mock_bing_search_response): + """Test that search properly sets params with an EqualTo filter.""" + + # Arrange + my_filter = TextSearchFilter.equal_to(field_name="freshness", value="Day") + options = TextSearchOptions(filter=my_filter) + + # Act + await bing_search.search("test query", options) + + # Assert + params = async_client_mock.get.call_args.kwargs["params"] + + assert params["count"] == options.top + assert params["offset"] == options.skip + # 'freshness' is recognized in QUERY_PARAMETERS, so 'freshness' should be set + assert "freshness" in params + assert params["freshness"] == "Day" + # 'q' should be a combination of the original query plus a plus sign + assert params["q"] == "test query+".strip() + + +async def test_search_not_recognized_filter(bing_search, async_client_mock, mock_bing_search_response): + """Test that search properly appends non-recognized filters to the q parameter.""" + + # Arrange + # 'customProperty' is presumably not in QUERY_PARAMETERS + my_filter = TextSearchFilter.equal_to(field_name="customProperty", value="customValue") + options = TextSearchOptions(filter=my_filter) + + # Act + await bing_search.search("test query", options) + + # Assert + params = async_client_mock.get.call_args.kwargs["params"] + assert params["count"] == options.top + assert params["offset"] == options.skip + assert "customProperty" not in params + # We expect 'q' to contain the extra query param in a plus-joined format + assert isinstance(params["q"], str) + assert "customProperty:customValue" in params["q"] diff --git a/python/tests/unit/connectors/search/google/test_google_search.py b/python/tests/unit/connectors/search/google/test_google_search.py new file mode 100644 index 000000000000..c54e592910a9 --- /dev/null +++ b/python/tests/unit/connectors/search/google/test_google_search.py @@ -0,0 +1,170 @@ +# Copyright (c) Microsoft. All rights reserved. + +from unittest.mock import AsyncMock, patch + +import pytest +from httpx import HTTPStatusError, RequestError, Response + +from semantic_kernel.connectors.search.google.google_search import GoogleSearch +from semantic_kernel.connectors.search.google.google_search_response import ( + GoogleSearchInformation, + GoogleSearchResponse, +) +from semantic_kernel.connectors.search.google.google_search_result import GoogleSearchResult +from semantic_kernel.data.filter_clauses.any_tags_equal_to_filter_clause import AnyTagsEqualTo +from semantic_kernel.data.filter_clauses.equal_to_filter_clause import EqualTo +from semantic_kernel.data.text_search.text_search_options import TextSearchOptions +from semantic_kernel.exceptions import ( + ServiceInitializationError, + ServiceInvalidRequestError, +) + + +@pytest.fixture +def google_search(google_search_unit_test_env): + """Fixture to return a GoogleSearch instance with valid settings.""" + return GoogleSearch() + + +async def test_google_search_init_success(google_search) -> None: + """Test that GoogleSearch successfully initializes with valid parameters.""" + # Should not raise any exception + assert google_search.settings.api_key.get_secret_value() == "test_api_key" + assert google_search.settings.engine_id == "test_id" + + +async def test_google_search_init_validation_error() -> None: + """Test that GoogleSearch raises ServiceInitializationError when GoogleSearchSettings creation fails.""" + with pytest.raises(ServiceInitializationError) as exc: + GoogleSearch(env_file_path="invalid.env") + assert "Failed to create Google settings." in str(exc.value) + + +async def test_google_search_top_greater_than_10_raises_error(google_search) -> None: + """Test that passing a top value greater than 10 raises ServiceInvalidRequestError.""" + options = TextSearchOptions() + options.top = 11 # Invalid + with pytest.raises(ServiceInvalidRequestError) as exc: + await google_search.search(query="test query", options=options) + assert "count value must be less than or equal to 10." in str(exc.value) + + +async def test_google_search_no_items_in_response(google_search) -> None: + """Test that when the response has no items, search results yield nothing.""" + mock_response = GoogleSearchResponse(items=None) + + # We'll mock _inner_search to return our mock_response + with patch.object(google_search, "_inner_search", new=AsyncMock(return_value=mock_response)): + result = await google_search.search("test") + # Extract all items from the AsyncIterable + items = [item async for item in result.results] + assert len(items) == 0 + + +async def test_google_search_partial_items_in_response(google_search) -> None: + """Test that snippets are properly returned in search results.""" + snippet_1 = "Snippet 1" + snippet_2 = "Snippet 2" + item_1 = GoogleSearchResult(snippet=snippet_1) + item_2 = GoogleSearchResult(snippet=snippet_2) + mock_response = GoogleSearchResponse(items=[item_1, item_2]) + + with patch.object(google_search, "_inner_search", new=AsyncMock(return_value=mock_response)): + result = await google_search.search("test") + items = [item async for item in result.results] + assert len(items) == 2 + assert items[0] == snippet_1 + assert items[1] == snippet_2 + + +async def test_google_search_request_http_status_error(google_search) -> None: + """Test that HTTP status errors raise ServiceInvalidRequestError.""" + # Mock the AsyncClient.get call to raise HTTPStatusError + with patch( + "httpx.AsyncClient.get", + new=AsyncMock(side_effect=HTTPStatusError("Error", request=None, response=Response(status_code=400))), + ): + with pytest.raises(ServiceInvalidRequestError) as exc: + await google_search.search("query") + assert "Failed to get search results." in str(exc.value) + + +async def test_google_search_request_error(google_search) -> None: + """Test that request errors raise ServiceInvalidRequestError.""" + # Mock the AsyncClient.get call to raise RequestError + with patch("httpx.AsyncClient.get", new=AsyncMock(side_effect=RequestError("Client error"))): + with pytest.raises(ServiceInvalidRequestError) as exc: + await google_search.search("query") + assert "A client error occurred while getting search results." in str(exc.value) + + +async def test_google_search_unexpected_error(google_search) -> None: + """Test that unexpected exceptions raise ServiceInvalidRequestError.""" + # Mock the AsyncClient.get call to raise a random exception + with patch("httpx.AsyncClient.get", new=AsyncMock(side_effect=Exception("Random error"))): + with pytest.raises(ServiceInvalidRequestError) as exc: + await google_search.search("query") + assert "An unexpected error occurred while getting search results." in str(exc.value) + + +async def test_get_text_search_results(google_search) -> None: + """Test that get_text_search_results returns TextSearchResults that contain name, value, and link.""" + item_1 = GoogleSearchResult(title="Title1", snippet="Snippet1", link="Link1") + item_2 = GoogleSearchResult(title="Title2", snippet="Snippet2", link="Link2") + mock_response = GoogleSearchResponse(items=[item_1, item_2]) + + with patch.object(google_search, "_inner_search", new=AsyncMock(return_value=mock_response)): + result = await google_search.get_text_search_results("test") + items = [item async for item in result.results] + assert len(items) == 2 + assert items[0].name == "Title1" + assert items[0].value == "Snippet1" + assert items[0].link == "Link1" + assert items[1].name == "Title2" + assert items[1].value == "Snippet2" + assert items[1].link == "Link2" + + +async def test_get_search_results(google_search) -> None: + """Test that get_search_results returns GoogleSearchResult items directly.""" + item_1 = GoogleSearchResult(title="Title1", snippet="Snippet1", link="Link1") + item_2 = GoogleSearchResult(title="Title2", snippet="Snippet2", link="Link2") + mock_response = GoogleSearchResponse(items=[item_1, item_2]) + + with patch.object(google_search, "_inner_search", new=AsyncMock(return_value=mock_response)): + result = await google_search.get_search_results("test") + items = [item async for item in result.results] + assert len(items) == 2 + assert items[0].title == "Title1" + assert items[1].link == "Link2" + + +async def test_build_query_equal_to_filter(google_search) -> None: + """Test that if an EqualTo filter is recognized, it is sent along in query params.""" + filters = [ + EqualTo(field_name="lr", value="lang_en"), + AnyTagsEqualTo(field_name="tags", value="tag1"), + ] # second one is not recognized + options = TextSearchOptions() + options.filter.filters = filters + + with patch.object(google_search, "_inner_search", new=AsyncMock(return_value=GoogleSearchResponse())): + await google_search.search(query="hello world", options=options) + + +async def test_google_search_includes_total_count(google_search) -> None: + """Test that total_count is included if include_total_count is True.""" + search_info = GoogleSearchInformation( + searchTime=0.23, totalResults="42", formattedSearchTime="0.23s", formattedTotalResults="42" + ) + mock_response = GoogleSearchResponse(search_information=search_info, items=None) + + with patch.object(google_search, "_inner_search", new=AsyncMock(return_value=mock_response)): + options = TextSearchOptions() + options.include_total_count = True # not standard, so we'll set it dynamically + result = await google_search.search(query="test query", options=options) + assert result.total_count == 42 + # if we set it to false, total_count should be None + options.include_total_count = False + result_no_count = await google_search.search(query="test query", options=options) + assert result_no_count.total_count is None