diff --git a/python/samples/concepts/chat_completion/openai_logit_bias.py b/python/samples/concepts/chat_completion/openai_logit_bias.py
index b003aa6b2acb..6035dcc4645c 100644
--- a/python/samples/concepts/chat_completion/openai_logit_bias.py
+++ b/python/samples/concepts/chat_completion/openai_logit_bias.py
@@ -6,7 +6,7 @@
 from semantic_kernel import Kernel
 from semantic_kernel.connectors.ai import PromptExecutionSettings
 from semantic_kernel.connectors.ai.open_ai import OpenAIChatCompletion, OpenAITextCompletion
-from semantic_kernel.contents import ChatHistory
+from semantic_kernel.contents import AuthorRole, ChatHistory
 from semantic_kernel.functions import KernelArguments
 from semantic_kernel.prompt_template import InputVariable, PromptTemplateConfig
@@ -204,7 +204,9 @@ def _check_banned_words(banned_list, actual_list) -> bool:
 def _format_output(chat, banned_words) -> None:
     print("--- Checking for banned words ---")
-    chat_bot_ans_words = [word for msg in chat.messages if msg.role == "assistant" for word in msg.content.split()]
+    chat_bot_ans_words = [
+        word for msg in chat.messages if msg.role == AuthorRole.ASSISTANT for word in msg.content.split()
+    ]
     if _check_banned_words(banned_words, chat_bot_ans_words):
         print("None of the banned words were found in the answer")
diff --git a/python/samples/concepts/filtering/function_invocation_filters_stream.py b/python/samples/concepts/filtering/function_invocation_filters_stream.py
index 62bd3d930835..17bca2cbaf24 100644
--- a/python/samples/concepts/filtering/function_invocation_filters_stream.py
+++ b/python/samples/concepts/filtering/function_invocation_filters_stream.py
@@ -6,6 +6,7 @@
 from functools import reduce
 from semantic_kernel.connectors.ai.open_ai.services.open_ai_chat_completion import OpenAIChatCompletion
+from semantic_kernel.contents import AuthorRole
 from semantic_kernel.contents.chat_history import ChatHistory
 from semantic_kernel.contents.streaming_chat_message_content import StreamingChatMessageContent
 from semantic_kernel.filters.filter_types import FilterTypes
@@ -39,7 +40,7 @@ async def override_stream(stream):
             async for partial in stream:
                 yield partial
         except Exception as e:
-            yield [StreamingChatMessageContent(author="assistant", content=f"Exception caught: {e}")]
+            yield [StreamingChatMessageContent(role=AuthorRole.ASSISTANT, content=f"Exception caught: {e}")]
 
     stream = context.result.value
     context.result = FunctionResult(function=context.result.function, value=override_stream(stream))
diff --git a/python/semantic_kernel/connectors/ai/ollama/services/ollama_chat_completion.py b/python/semantic_kernel/connectors/ai/ollama/services/ollama_chat_completion.py
index 0c64849823ab..db35a54c58c4 100644
--- a/python/semantic_kernel/connectors/ai/ollama/services/ollama_chat_completion.py
+++ b/python/semantic_kernel/connectors/ai/ollama/services/ollama_chat_completion.py
@@ -12,6 +12,7 @@
 from semantic_kernel.connectors.ai.ollama.ollama_prompt_execution_settings import OllamaChatPromptExecutionSettings
 from semantic_kernel.connectors.ai.ollama.utils import AsyncSession
 from semantic_kernel.connectors.ai.text_completion_client_base import TextCompletionClientBase
+from semantic_kernel.contents import AuthorRole
 from semantic_kernel.contents.chat_history import ChatHistory
 from semantic_kernel.contents.chat_message_content import ChatMessageContent
 from semantic_kernel.contents.streaming_chat_message_content import StreamingChatMessageContent
@@ -66,7 +67,7 @@ async def get_chat_message_contents(
             ChatMessageContent(
                 inner_content=response_object,
                 ai_model_id=self.ai_model_id,
-                role="assistant",
+                role=AuthorRole.ASSISTANT,
                 content=response_object.get("message", {"content": None}).get("content", None),
             )
         ]
@@ -105,7 +106,7 @@ async def get_streaming_chat_message_contents(
                     break
                 yield [
                     StreamingChatMessageContent(
-                        role="assistant",
+                        role=AuthorRole.ASSISTANT,
                         choice_index=0,
                         inner_content=body,
                         ai_model_id=self.ai_model_id,
@@ -131,7 +132,7 @@ async def get_text_contents(
         """
         if not settings.ai_model_id:
             settings.ai_model_id = self.ai_model_id
-        settings.messages = [{"role": "user", "content": prompt}]
+        settings.messages = [{"role": AuthorRole.USER, "content": prompt}]
         settings.stream = False
         async with (
             AsyncSession(self.session) as session,
@@ -165,7 +166,7 @@ async def get_streaming_text_contents(
         """
         if not settings.ai_model_id:
             settings.ai_model_id = self.ai_model_id
-        settings.messages = [{"role": "user", "content": prompt}]
+        settings.messages = [{"role": AuthorRole.USER, "content": prompt}]
         settings.stream = True
         async with (
             AsyncSession(self.session) as session,
diff --git a/python/semantic_kernel/connectors/ai/open_ai/services/open_ai_chat_completion_base.py b/python/semantic_kernel/connectors/ai/open_ai/services/open_ai_chat_completion_base.py
index e6ebe70a780c..ab36599eb2df 100644
--- a/python/semantic_kernel/connectors/ai/open_ai/services/open_ai_chat_completion_base.py
+++ b/python/semantic_kernel/connectors/ai/open_ai/services/open_ai_chat_completion_base.py
@@ -249,12 +249,12 @@ async def get_streaming_chat_message_contents(
     def _chat_message_content_to_dict(self, message: "ChatMessageContent") -> dict[str, str | None]:
         msg = super()._chat_message_content_to_dict(message)
-        if message.role == "assistant":
+        if message.role == AuthorRole.ASSISTANT:
             if tool_calls := getattr(message, "tool_calls", None):
                 msg["tool_calls"] = [tool_call.model_dump() for tool_call in tool_calls]
             if function_call := getattr(message, "function_call", None):
                 msg["function_call"] = function_call.model_dump_json()
-        if message.role == "tool":
+        if message.role == AuthorRole.TOOL:
             if tool_call_id := getattr(message, "tool_call_id", None):
                 msg["tool_call_id"] = tool_call_id
             if message.metadata and "function" in message.metadata:
diff --git a/python/semantic_kernel/contents/function_result_content.py b/python/semantic_kernel/contents/function_result_content.py
index e9d28461ff72..be4a4402d783 100644
--- a/python/semantic_kernel/contents/function_result_content.py
+++ b/python/semantic_kernel/contents/function_result_content.py
@@ -6,6 +6,7 @@
 from pydantic import field_validator
+from semantic_kernel.contents.author_role import AuthorRole
 from semantic_kernel.contents.const import FUNCTION_RESULT_CONTENT_TAG, TEXT_CONTENT_TAG
 from semantic_kernel.contents.kernel_content import KernelContent
 from semantic_kernel.contents.text_content import TextContent
@@ -104,8 +105,8 @@ def to_chat_message_content(self, unwrap: bool = False) -> "ChatMessageContent":
         from semantic_kernel.contents.chat_message_content import ChatMessageContent
 
         if unwrap:
-            return ChatMessageContent(role="tool", items=[self.result])  # type: ignore
-        return ChatMessageContent(role="tool", items=[self])  # type: ignore
+            return ChatMessageContent(role=AuthorRole.TOOL, items=[self.result])  # type: ignore
+        return ChatMessageContent(role=AuthorRole.TOOL, items=[self])  # type: ignore
 
     def to_dict(self) -> dict[str, str]:
         """Convert the instance to a dictionary."""
diff --git a/python/tests/unit/connectors/open_ai/services/test_open_ai_chat_completion_base.py b/python/tests/unit/connectors/open_ai/services/test_open_ai_chat_completion_base.py
index ab5d011e7b09..2e2cb8903502 100644
--- a/python/tests/unit/connectors/open_ai/services/test_open_ai_chat_completion_base.py
+++ b/python/tests/unit/connectors/open_ai/services/test_open_ai_chat_completion_base.py
@@ -10,7 +10,7 @@
     OpenAIChatPromptExecutionSettings,
 )
 from semantic_kernel.connectors.ai.open_ai.services.open_ai_chat_completion import OpenAIChatCompletionBase
-from semantic_kernel.contents import ChatMessageContent, StreamingChatMessageContent, TextContent
+from semantic_kernel.contents import AuthorRole, ChatMessageContent, StreamingChatMessageContent, TextContent
 from semantic_kernel.contents.chat_history import ChatHistory
 from semantic_kernel.contents.function_call_content import FunctionCallContent
 from semantic_kernel.exceptions import FunctionCallInvalidArgumentsException
@@ -64,7 +64,9 @@ async def test_complete_chat(tool_call, kernel: Kernel):
     settings.function_call_behavior = None
     mock_function_call = MagicMock(spec=FunctionCallContent)
     mock_text = MagicMock(spec=TextContent)
-    mock_message = ChatMessageContent(role="assistant", items=[mock_function_call] if tool_call else [mock_text])
+    mock_message = ChatMessageContent(
+        role=AuthorRole.ASSISTANT, items=[mock_function_call] if tool_call else [mock_text]
+    )
     mock_message_content = [mock_message]
     arguments = KernelArguments()
diff --git a/python/tests/unit/contents/test_chat_history.py b/python/tests/unit/contents/test_chat_history.py
index 33a8a1439712..89fdb1925b8e 100644
--- a/python/tests/unit/contents/test_chat_history.py
+++ b/python/tests/unit/contents/test_chat_history.py
@@ -213,9 +213,9 @@ def test_dump():
     )
     dump = chat_history.model_dump(exclude_none=True)
     assert dump is not None
-    assert dump["messages"][0]["role"] == "system"
+    assert dump["messages"][0]["role"] == AuthorRole.SYSTEM
     assert dump["messages"][0]["items"][0]["text"] == system_msg
-    assert dump["messages"][1]["role"] == "user"
+    assert dump["messages"][1]["role"] == AuthorRole.USER
     assert dump["messages"][1]["items"][0]["text"] == "Message"
diff --git a/python/tests/unit/contents/test_chat_message_content.py b/python/tests/unit/contents/test_chat_message_content.py
index a2eeec17a9fb..38ee93d8e6bb 100644
--- a/python/tests/unit/contents/test_chat_message_content.py
+++ b/python/tests/unit/contents/test_chat_message_content.py
@@ -12,7 +12,7 @@ def test_cmc():
-    message = ChatMessageContent(role="user", content="Hello, world!")
+    message = ChatMessageContent(role=AuthorRole.USER, content="Hello, world!")
     assert message.role == AuthorRole.USER
     assert message.content == "Hello, world!"
     assert len(message.items) == 1
@@ -20,12 +20,13 @@ def test_cmc_str():
     message = ChatMessageContent(role="user", content="Hello, world!")
+    assert message.role == AuthorRole.USER
     assert str(message) == "Hello, world!"
 
 
 def test_cmc_full():
     message = ChatMessageContent(
-        role="user",
+        role=AuthorRole.USER,
         name="username",
         content="Hello, world!",
         inner_content="Hello, world!",
@@ -42,14 +43,14 @@ def test_cmc_items():
-    message = ChatMessageContent(role="user", items=[TextContent(text="Hello, world!")])
+    message = ChatMessageContent(role=AuthorRole.USER, items=[TextContent(text="Hello, world!")])
     assert message.role == AuthorRole.USER
     assert message.content == "Hello, world!"
     assert len(message.items) == 1
 
 
 def test_cmc_items_and_content():
-    message = ChatMessageContent(role="user", content="text", items=[TextContent(text="Hello, world!")])
+    message = ChatMessageContent(role=AuthorRole.USER, content="text", items=[TextContent(text="Hello, world!")])
     assert message.role == AuthorRole.USER
     assert message.content == "Hello, world!"
     assert message.items[0].text == "Hello, world!"
@@ -59,7 +60,7 @@ def test_cmc_multiple_items():
     message = ChatMessageContent(
-        role="system",
+        role=AuthorRole.SYSTEM,
         items=[
             TextContent(text="Hello, world!"),
             TextContent(text="Hello, world!"),
@@ -71,7 +72,7 @@ def test_cmc_content_set():
-    message = ChatMessageContent(role="user", content="Hello, world!")
+    message = ChatMessageContent(role=AuthorRole.USER, content="Hello, world!")
     assert message.role == AuthorRole.USER
     assert message.content == "Hello, world!"
     message.content = "Hello, world to you too!"
@@ -82,7 +83,7 @@ def test_cmc_content_set_empty():
-    message = ChatMessageContent(role="user", content="Hello, world!")
+    message = ChatMessageContent(role=AuthorRole.USER, content="Hello, world!")
     assert message.role == AuthorRole.USER
     assert message.content == "Hello, world!"
     message.items.pop()
@@ -92,7 +93,7 @@ def test_cmc_to_element():
-    message = ChatMessageContent(role="user", content="Hello, world!", name=None)
+    message = ChatMessageContent(role=AuthorRole.USER, content="Hello, world!", name=None)
     element = message.to_element()
     assert element.tag == "message"
     assert element.attrib == {"role": "user"}
@@ -103,13 +104,13 @@ def test_cmc_to_prompt():
-    message = ChatMessageContent(role="user", content="Hello, world!")
+    message = ChatMessageContent(role=AuthorRole.USER, content="Hello, world!")
     prompt = message.to_prompt()
     assert prompt == 'Hello, world!'
 
 
 def test_cmc_from_element():
-    element = ChatMessageContent(role="user", content="Hello, world!").to_element()
+    element = ChatMessageContent(role=AuthorRole.USER, content="Hello, world!").to_element()
     message = ChatMessageContent.from_element(element)
     assert message.role == AuthorRole.USER
     assert message.content == "Hello, world!"
@@ -182,14 +183,14 @@ def test_cmc_from_element_content_parse(xml_content, user, text_content, length)
 def test_cmc_serialize():
-    message = ChatMessageContent(role="user", content="Hello, world!")
+    message = ChatMessageContent(role=AuthorRole.USER, content="Hello, world!")
     dumped = message.model_dump()
-    assert dumped["role"] == "user"
+    assert dumped["role"] == AuthorRole.USER
     assert dumped["items"][0]["text"] == "Hello, world!"
 
 
 def test_cmc_to_dict():
-    message = ChatMessageContent(role="user", content="Hello, world!")
+    message = ChatMessageContent(role=AuthorRole.USER, content="Hello, world!")
     assert message.to_dict() == {
         "role": "user",
         "content": "Hello, world!",
     }
@@ -197,7 +198,7 @@ def test_cmc_to_dict_keys():
-    message = ChatMessageContent(role="user", content="Hello, world!")
+    message = ChatMessageContent(role=AuthorRole.USER, content="Hello, world!")
     assert message.to_dict(role_key="author", content_key="text") == {
         "author": "user",
         "text": "Hello, world!",
diff --git a/python/tests/unit/contents/test_streaming_chat_message_content.py b/python/tests/unit/contents/test_streaming_chat_message_content.py
index a6d13430a37a..f09f6c3408be 100644
--- a/python/tests/unit/contents/test_streaming_chat_message_content.py
+++ b/python/tests/unit/contents/test_streaming_chat_message_content.py
@@ -15,7 +15,7 @@ def test_scmc():
-    message = StreamingChatMessageContent(choice_index=0, role="user", content="Hello, world!")
+    message = StreamingChatMessageContent(choice_index=0, role=AuthorRole.USER, content="Hello, world!")
     assert message.role == AuthorRole.USER
     assert message.content == "Hello, world!"
     assert len(message.items) == 1
@@ -29,7 +29,7 @@ def test_scmc_str():
 def test_scmc_full():
     message = StreamingChatMessageContent(
         choice_index=0,
-        role="user",
+        role=AuthorRole.USER,
         name="username",
         content="Hello, world!",
         inner_content="Hello, world!",
@@ -46,7 +46,9 @@ def test_scmc_items():
-    message = StreamingChatMessageContent(choice_index=0, role="user", items=[TextContent(text="Hello, world!")])
+    message = StreamingChatMessageContent(
+        choice_index=0, role=AuthorRole.USER, items=[TextContent(text="Hello, world!")]
+    )
     assert message.role == AuthorRole.USER
     assert message.content == "Hello, world!"
     assert len(message.items) == 1
@@ -54,7 +56,7 @@ def test_scmc_items_and_content():
     message = StreamingChatMessageContent(
-        choice_index=0, role="user", content="text", items=[TextContent(text="Hello, world!")]
+        choice_index=0, role=AuthorRole.USER, content="text", items=[TextContent(text="Hello, world!")]
     )
     assert message.role == AuthorRole.USER
     assert message.content == "Hello, world!"
@@ -66,7 +68,7 @@ def test_scmc_multiple_items():
     message = StreamingChatMessageContent(
         choice_index=0,
-        role="system",
+        role=AuthorRole.SYSTEM,
         items=[
             TextContent(text="Hello, world!"),
             TextContent(text="Hello, world!"),
@@ -78,7 +80,7 @@ def test_scmc_content_set():
-    message = StreamingChatMessageContent(choice_index=0, role="user", content="Hello, world!")
+    message = StreamingChatMessageContent(choice_index=0, role=AuthorRole.USER, content="Hello, world!")
     assert message.role == AuthorRole.USER
     assert message.content == "Hello, world!"
     message.content = "Hello, world to you too!"
@@ -89,7 +91,7 @@ def test_scmc_content_set_empty():
-    message = StreamingChatMessageContent(choice_index=0, role="user", content="Hello, world!")
+    message = StreamingChatMessageContent(choice_index=0, role=AuthorRole.USER, content="Hello, world!")
     assert message.role == AuthorRole.USER
     assert message.content == "Hello, world!"
     message.items.pop()
@@ -99,7 +101,7 @@ def test_scmc_to_element():
-    message = StreamingChatMessageContent(choice_index=0, role="user", content="Hello, world!", name=None)
+    message = StreamingChatMessageContent(choice_index=0, role=AuthorRole.USER, content="Hello, world!", name=None)
     element = message.to_element()
     assert element.tag == "message"
     assert element.attrib == {"role": "user", "choice_index": "0"}
@@ -110,7 +112,7 @@ def test_scmc_to_prompt():
-    message = StreamingChatMessageContent(choice_index=0, role="user", content="Hello, world!")
+    message = StreamingChatMessageContent(choice_index=0, role=AuthorRole.USER, content="Hello, world!")
     prompt = message.to_prompt()
     assert "Hello, world!" in prompt
     assert 'choice_index="0"' in prompt
@@ -118,7 +120,7 @@ def test_scmc_from_element():
-    element = StreamingChatMessageContent(choice_index=0, role="user", content="Hello, world!").to_element()
+    element = StreamingChatMessageContent(choice_index=0, role=AuthorRole.USER, content="Hello, world!").to_element()
     message = StreamingChatMessageContent.from_element(element)
     assert message.role == AuthorRole.USER
     assert message.content == "Hello, world!"
@@ -188,14 +190,14 @@ def test_scmc_from_element_content_parse(xml_content, user, text_content, length)
 def test_scmc_serialize():
-    message = StreamingChatMessageContent(choice_index=0, role="user", content="Hello, world!")
+    message = StreamingChatMessageContent(choice_index=0, role=AuthorRole.USER, content="Hello, world!")
     dumped = message.model_dump()
-    assert dumped["role"] == "user"
+    assert dumped["role"] == AuthorRole.USER
     assert dumped["items"][0]["text"] == "Hello, world!"
 
 
 def test_scmc_to_dict():
-    message = StreamingChatMessageContent(choice_index=0, role="user", content="Hello, world!")
+    message = StreamingChatMessageContent(choice_index=0, role=AuthorRole.USER, content="Hello, world!")
     assert message.to_dict() == {
         "role": "user",
         "content": "Hello, world!",
     }
@@ -203,7 +205,7 @@ def test_scmc_to_dict_keys():
-    message = StreamingChatMessageContent(choice_index=0, role="user", content="Hello, world!")
+    message = StreamingChatMessageContent(choice_index=0, role=AuthorRole.USER, content="Hello, world!")
     assert message.to_dict(role_key="author", content_key="text") == {
         "author": "user",
         "text": "Hello, world!",
@@ -254,8 +256,12 @@ def test_scmc_to_dict_items(input_args, expected_dict):
 def test_scmc_add():
-    message1 = StreamingChatMessageContent(choice_index=0, role="user", content="Hello, ", inner_content="source1")
-    message2 = StreamingChatMessageContent(choice_index=0, role="user", content="world!", inner_content="source2")
+    message1 = StreamingChatMessageContent(
+        choice_index=0, role=AuthorRole.USER, content="Hello, ", inner_content="source1"
+    )
+    message2 = StreamingChatMessageContent(
+        choice_index=0, role=AuthorRole.USER, content="world!", inner_content="source2"
+    )
     combined = message1 + message2
     assert combined.role == AuthorRole.USER
     assert combined.content == "Hello, world!"
@@ -264,9 +270,13 @@ def test_scmc_add():
 def test_scmc_add_three():
-    message1 = StreamingChatMessageContent(choice_index=0, role="user", content="Hello, ", inner_content="source1")
-    message2 = StreamingChatMessageContent(choice_index=0, role="user", content="world", inner_content="source2")
-    message3 = StreamingChatMessageContent(choice_index=0, role="user", content="!", inner_content="source3")
+    message1 = StreamingChatMessageContent(
+        choice_index=0, role=AuthorRole.USER, content="Hello, ", inner_content="source1"
+    )
+    message2 = StreamingChatMessageContent(
+        choice_index=0, role=AuthorRole.USER, content="world", inner_content="source2"
+    )
+    message3 = StreamingChatMessageContent(choice_index=0, role=AuthorRole.USER, content="!", inner_content="source3")
     combined = message1 + message2 + message3
     assert combined.role == AuthorRole.USER
     assert combined.content == "Hello, world!"
@@ -277,13 +287,13 @@ def test_scmc_add_different_items():
     message1 = StreamingChatMessageContent(
         choice_index=0,
-        role="user",
+        role=AuthorRole.USER,
         items=[StreamingTextContent(choice_index=0, text="Hello, ")],
         inner_content="source1",
     )
     message2 = StreamingChatMessageContent(
         choice_index=0,
-        role="user",
+        role=AuthorRole.USER,
         items=[FunctionResultContent(id="test", name="test", result="test")],
         inner_content="source2",
     )
@@ -298,24 +308,24 @@
     "message1, message2",
     [
         (
-            StreamingChatMessageContent(choice_index=0, role="user", content="Hello, "),
-            StreamingChatMessageContent(choice_index=0, role="assistant", content="world!"),
+            StreamingChatMessageContent(choice_index=0, role=AuthorRole.USER, content="Hello, "),
+            StreamingChatMessageContent(choice_index=0, role=AuthorRole.ASSISTANT, content="world!"),
         ),
         (
-            StreamingChatMessageContent(choice_index=0, role="user", content="Hello, "),
-            StreamingChatMessageContent(choice_index=1, role="user", content="world!"),
+            StreamingChatMessageContent(choice_index=0, role=AuthorRole.USER, content="Hello, "),
+            StreamingChatMessageContent(choice_index=1, role=AuthorRole.USER, content="world!"),
        ),
         (
-            StreamingChatMessageContent(choice_index=0, role="user", content="Hello, ", ai_model_id="1234"),
-            StreamingChatMessageContent(choice_index=0, role="user", content="world!", ai_model_id="5678"),
+            StreamingChatMessageContent(choice_index=0, role=AuthorRole.USER, content="Hello, ", ai_model_id="1234"),
+            StreamingChatMessageContent(choice_index=0, role=AuthorRole.USER, content="world!", ai_model_id="5678"),
         ),
         (
-            StreamingChatMessageContent(choice_index=0, role="user", content="Hello, ", encoding="utf-8"),
-            StreamingChatMessageContent(choice_index=0, role="user", content="world!", encoding="utf-16"),
+            StreamingChatMessageContent(choice_index=0, role=AuthorRole.USER, content="Hello, ", encoding="utf-8"),
+            StreamingChatMessageContent(choice_index=0, role=AuthorRole.USER, content="world!", encoding="utf-16"),
         ),
         (
-            StreamingChatMessageContent(choice_index=0, role="user", content="Hello, "),
-            ChatMessageContent(role="user", content="world!"),
+            StreamingChatMessageContent(choice_index=0, role=AuthorRole.USER, content="Hello, "),
+            ChatMessageContent(role=AuthorRole.USER, content="world!"),
         ),
     ],
     ids=["different_roles", "different_index", "different_model", "different_encoding", "different_type"],
@@ -326,5 +336,5 @@ def test_smsc_add_exception(message1, message2):
 def test_scmc_bytes():
-    message = StreamingChatMessageContent(choice_index=0, role="user", content="Hello, world!")
+    message = StreamingChatMessageContent(choice_index=0, role=AuthorRole.USER, content="Hello, world!")
     assert bytes(message) == b"Hello, world!"
diff --git a/python/tests/unit/functions/test_kernel_function_from_prompt.py b/python/tests/unit/functions/test_kernel_function_from_prompt.py
index 293ea5e28741..21abc647a0df 100644
--- a/python/tests/unit/functions/test_kernel_function_from_prompt.py
+++ b/python/tests/unit/functions/test_kernel_function_from_prompt.py
@@ -9,6 +9,7 @@
 from semantic_kernel.connectors.ai.open_ai.services.open_ai_text_completion import OpenAITextCompletion
 from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings
 from semantic_kernel.const import METADATA_EXCEPTION_KEY
+from semantic_kernel.contents import AuthorRole
 from semantic_kernel.contents.chat_message_content import ChatMessageContent
 from semantic_kernel.contents.streaming_chat_message_content import StreamingChatMessageContent
 from semantic_kernel.contents.text_content import TextContent
@@ -161,14 +162,16 @@ async def test_invoke_chat_stream(openai_unit_test_env):
     with patch(
         "semantic_kernel.connectors.ai.open_ai.services.open_ai_chat_completion.OpenAIChatCompletion.get_chat_message_contents"
     ) as mock:
-        mock.return_value = [ChatMessageContent(role="assistant", content="test", metadata={})]
+        mock.return_value = [ChatMessageContent(role=AuthorRole.ASSISTANT, content="test", metadata={})]
         result = await function.invoke(kernel=kernel)
         assert str(result) == "test"
     with patch(
         "semantic_kernel.connectors.ai.open_ai.services.open_ai_chat_completion.OpenAIChatCompletion.get_streaming_chat_message_contents"
     ) as mock:
-        mock.return_value = [StreamingChatMessageContent(choice_index=0, role="assistant", content="test", metadata={})]
+        mock.return_value = [
+            StreamingChatMessageContent(choice_index=0, role=AuthorRole.ASSISTANT, content="test", metadata={})
+        ]
         async for result in function.invoke_stream(kernel=kernel):
             assert str(result) == "test"
@@ -187,7 +190,7 @@ async def test_invoke_exception(openai_unit_test_env):
         "semantic_kernel.connectors.ai.open_ai.services.open_ai_chat_completion.OpenAIChatCompletion.get_chat_message_contents",
         side_effect=Exception,
     ) as mock:
-        mock.return_value = [ChatMessageContent(role="assistant", content="test", metadata={})]
+        mock.return_value = [ChatMessageContent(role=AuthorRole.ASSISTANT, content="test", metadata={})]
         with pytest.raises(Exception, match="test"):
             await function.invoke(kernel=kernel)
@@ -195,7 +198,9 @@
         "semantic_kernel.connectors.ai.open_ai.services.open_ai_chat_completion.OpenAIChatCompletion.get_streaming_chat_message_contents",
         side_effect=Exception,
     ) as mock:
-        mock.return_value = [StreamingChatMessageContent(choice_index=0, role="assistant", content="test", metadata={})]
+        mock.return_value = [
+            StreamingChatMessageContent(choice_index=0, role=AuthorRole.ASSISTANT, content="test", metadata={})
+        ]
         with pytest.raises(Exception):
             async for result in function.invoke_stream(kernel=kernel):
                 assert isinstance(result.metadata[METADATA_EXCEPTION_KEY], Exception)
@@ -270,7 +275,7 @@ async def test_invoke_defaults(openai_unit_test_env):
     with patch(
         "semantic_kernel.connectors.ai.open_ai.services.open_ai_chat_completion.OpenAIChatCompletion.get_chat_message_contents"
     ) as mock:
-        mock.return_value = [ChatMessageContent(role="assistant", content="test", metadata={})]
+        mock.return_value = [ChatMessageContent(role=AuthorRole.ASSISTANT, content="test", metadata={})]
         result = await function.invoke(kernel=kernel)
         assert str(result) == "test"
@@ -313,7 +318,7 @@ async def test_create_with_multiple_settings_one_service_registered(openai_unit_
     with patch(
         "semantic_kernel.connectors.ai.open_ai.services.open_ai_chat_completion.OpenAIChatCompletion.get_chat_message_contents"
     ) as mock:
-        mock.return_value = [ChatMessageContent(role="assistant", content="test", metadata={})]
+        mock.return_value = [ChatMessageContent(role=AuthorRole.ASSISTANT, content="test", metadata={})]
         result = await function.invoke(kernel=kernel)
         assert str(result) == "test"
diff --git a/python/tests/unit/planners/function_calling_stepwise_planner/test_function_calling_stepwise_planner.py b/python/tests/unit/planners/function_calling_stepwise_planner/test_function_calling_stepwise_planner.py
index 8092815094b5..0f89df2a62eb 100644
--- a/python/tests/unit/planners/function_calling_stepwise_planner/test_function_calling_stepwise_planner.py
+++ b/python/tests/unit/planners/function_calling_stepwise_planner/test_function_calling_stepwise_planner.py
@@ -5,6 +5,7 @@
 import pytest
 from semantic_kernel.connectors.ai.open_ai.services.open_ai_chat_completion import OpenAIChatCompletion
+from semantic_kernel.contents import AuthorRole
 from semantic_kernel.exceptions.planner_exceptions import PlannerInvalidConfigurationError
 from semantic_kernel.functions.kernel_arguments import KernelArguments
 from semantic_kernel.functions.kernel_function import KernelFunction
@@ -60,7 +61,7 @@ async def test_build_chat_history_for_step():
         "goal", "initial_plan", kernel_mock, arguments_mock, service_mock
     )
     assert chat_history is not None
-    assert chat_history[0].role == "user"
+    assert chat_history[0].role == AuthorRole.USER
 
 
 @pytest.mark.asyncio
diff --git a/python/tests/unit/prompt_template/test_handlebars_prompt_template.py b/python/tests/unit/prompt_template/test_handlebars_prompt_template.py
index 387dd8458be8..13468f6d35a0 100644
--- a/python/tests/unit/prompt_template/test_handlebars_prompt_template.py
+++ b/python/tests/unit/prompt_template/test_handlebars_prompt_template.py
@@ -3,6 +3,7 @@
 import pytest
 from pytest import mark
+from semantic_kernel.contents import AuthorRole
 from semantic_kernel.contents.chat_history import ChatHistory
 from semantic_kernel.contents.chat_message_content import ChatMessageContent
 from semantic_kernel.contents.function_call_content import FunctionCallContent
@@ -285,10 +286,12 @@ async def test_helpers_message_to_prompt(kernel: Kernel):
     chat_history = ChatHistory()
     chat_history.add_user_message("User message")
     chat_history.add_message(
-        ChatMessageContent(role="assistant", items=[FunctionCallContent(id="1", name="plug-test")])
+        ChatMessageContent(role=AuthorRole.ASSISTANT, items=[FunctionCallContent(id="1", name="plug-test")])
     )
     chat_history.add_message(
-        ChatMessageContent(role="tool", items=[FunctionResultContent(id="1", name="plug-test", result="Tool message")])
+        ChatMessageContent(
+            role=AuthorRole.TOOL, items=[FunctionResultContent(id="1", name="plug-test", result="Tool message")]
+        )
     )
     rendered = await target.render(kernel, KernelArguments(chat_history=chat_history))
diff --git a/python/tests/unit/prompt_template/test_jinja2_prompt_template.py b/python/tests/unit/prompt_template/test_jinja2_prompt_template.py
index 33ed7ffdf5db..cb5a3a34e86e 100644
--- a/python/tests/unit/prompt_template/test_jinja2_prompt_template.py
+++ b/python/tests/unit/prompt_template/test_jinja2_prompt_template.py
@@ -3,6 +3,7 @@
 import pytest
 from pytest import mark
+from semantic_kernel.contents import AuthorRole
 from semantic_kernel.contents.chat_history import ChatHistory
 from semantic_kernel.contents.chat_message_content import ChatMessageContent
 from semantic_kernel.contents.function_call_content import FunctionCallContent
@@ -293,10 +294,12 @@ async def test_helpers_message_to_prompt(kernel: Kernel):
     chat_history = ChatHistory()
     chat_history.add_user_message("User message")
     chat_history.add_message(
-        ChatMessageContent(role="assistant", items=[FunctionCallContent(id="1", name="plug-test")])
+        ChatMessageContent(role=AuthorRole.ASSISTANT, items=[FunctionCallContent(id="1", name="plug-test")])
     )
     chat_history.add_message(
-        ChatMessageContent(role="tool", items=[FunctionResultContent(id="1", name="plug-test", result="Tool message")])
+        ChatMessageContent(
+            role=AuthorRole.TOOL, items=[FunctionResultContent(id="1", name="plug-test", result="Tool message")]
+        )
     )
     rendered = await target.render(kernel, KernelArguments(chat_history=chat_history))
diff --git a/python/tests/unit/prompt_template/test_prompt_template_e2e.py b/python/tests/unit/prompt_template/test_prompt_template_e2e.py
index 75000c462686..a145b29cbc04 100644
--- a/python/tests/unit/prompt_template/test_prompt_template_e2e.py
+++ b/python/tests/unit/prompt_template/test_prompt_template_e2e.py
@@ -5,6 +5,7 @@
 from pytest import mark, raises
 from semantic_kernel import Kernel
+from semantic_kernel.contents import AuthorRole
 from semantic_kernel.contents.chat_history import ChatHistory
 from semantic_kernel.exceptions import TemplateSyntaxError
 from semantic_kernel.functions import kernel_function
@@ -360,13 +361,13 @@ async def test_renders_and_can_be_parsed(kernel: Kernel):
     ).render(kernel, KernelArguments(unsafe_input=unsafe_input, safe_input=safe_input))
     chat_history = ChatHistory.from_rendered_prompt(result)
     assert chat_history
-    assert chat_history.messages[0].role == "system"
+    assert chat_history.messages[0].role == AuthorRole.SYSTEM
     assert chat_history.messages[0].content == "This is the system message"
-    assert chat_history.messages[1].role == "user"
+    assert chat_history.messages[1].role == AuthorRole.USER
     assert chat_history.messages[1].content == "This is the newer system message"
-    assert chat_history.messages[2].role == "user"
+    assert chat_history.messages[2].role == AuthorRole.USER
     assert chat_history.messages[2].content == "This is bold text"
-    assert chat_history.messages[3].role == "user"
+    assert chat_history.messages[3].role == AuthorRole.USER
     assert chat_history.messages[3].content == "This is the newest system message"
@@ -397,11 +398,11 @@ async def test_renders_and_can_be_parsed_with_cdata_sections(kernel: Kernel):
     )
     chat_history = ChatHistory.from_rendered_prompt(result)
     assert chat_history
-    assert chat_history.messages[0].role == "user"
+    assert chat_history.messages[0].role == AuthorRole.USER
     assert chat_history.messages[0].content == "This is the newer system message"
-    assert chat_history.messages[1].role == "user"
+    assert chat_history.messages[1].role == AuthorRole.USER
     assert chat_history.messages[1].content == "explain imagehttps://fake-link-to-image/"
-    assert chat_history.messages[2].role == "user"
+    assert chat_history.messages[2].role == AuthorRole.USER
     assert (
         chat_history.messages[2].content == "]]>This is the newer system message
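
The pattern applied throughout this patch is mechanical: role comparisons and constructor arguments use AuthorRole enum members instead of raw strings such as "assistant" or "tool". Below is a minimal usage sketch of that pattern, not part of the patch itself, using only the imports and methods that already appear in the diff above:

# Sketch of the AuthorRole pattern this patch applies (illustrative, not from the diff).
from semantic_kernel.contents import AuthorRole, ChatHistory, ChatMessageContent

chat_history = ChatHistory()
chat_history.add_user_message("Hello!")  # stored with role AuthorRole.USER
chat_history.add_message(ChatMessageContent(role=AuthorRole.ASSISTANT, content="Hi, how can I help?"))

# Role checks compare against the enum member rather than the literal string "assistant".
assistant_words = [
    word for msg in chat_history.messages if msg.role == AuthorRole.ASSISTANT for word in msg.content.split()
]
print(assistant_words)

As the updated tests show, the enum is what model_dump() returns for the role field, while to_dict() still serializes it as the plain string (for example "role": "user"), so wire-format payloads are unchanged.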