diff --git a/python/samples/concepts/chat_completion/openai_logit_bias.py b/python/samples/concepts/chat_completion/openai_logit_bias.py
index b003aa6b2acb..6035dcc4645c 100644
--- a/python/samples/concepts/chat_completion/openai_logit_bias.py
+++ b/python/samples/concepts/chat_completion/openai_logit_bias.py
@@ -6,7 +6,7 @@
from semantic_kernel import Kernel
from semantic_kernel.connectors.ai import PromptExecutionSettings
from semantic_kernel.connectors.ai.open_ai import OpenAIChatCompletion, OpenAITextCompletion
-from semantic_kernel.contents import ChatHistory
+from semantic_kernel.contents import AuthorRole, ChatHistory
from semantic_kernel.functions import KernelArguments
from semantic_kernel.prompt_template import InputVariable, PromptTemplateConfig
@@ -204,7 +204,9 @@ def _check_banned_words(banned_list, actual_list) -> bool:
def _format_output(chat, banned_words) -> None:
print("--- Checking for banned words ---")
- chat_bot_ans_words = [word for msg in chat.messages if msg.role == "assistant" for word in msg.content.split()]
+ chat_bot_ans_words = [
+ word for msg in chat.messages if msg.role == AuthorRole.ASSISTANT for word in msg.content.split()
+ ]
if _check_banned_words(banned_words, chat_bot_ans_words):
print("None of the banned words were found in the answer")
diff --git a/python/samples/concepts/filtering/function_invocation_filters_stream.py b/python/samples/concepts/filtering/function_invocation_filters_stream.py
index 62bd3d930835..17bca2cbaf24 100644
--- a/python/samples/concepts/filtering/function_invocation_filters_stream.py
+++ b/python/samples/concepts/filtering/function_invocation_filters_stream.py
@@ -6,6 +6,7 @@
from functools import reduce
from semantic_kernel.connectors.ai.open_ai.services.open_ai_chat_completion import OpenAIChatCompletion
+from semantic_kernel.contents import AuthorRole
from semantic_kernel.contents.chat_history import ChatHistory
from semantic_kernel.contents.streaming_chat_message_content import StreamingChatMessageContent
from semantic_kernel.filters.filter_types import FilterTypes
@@ -39,7 +40,7 @@ async def override_stream(stream):
async for partial in stream:
yield partial
except Exception as e:
- yield [StreamingChatMessageContent(author="assistant", content=f"Exception caught: {e}")]
+ yield [StreamingChatMessageContent(role=AuthorRole.ASSISTANT, content=f"Exception caught: {e}")]
stream = context.result.value
context.result = FunctionResult(function=context.result.function, value=override_stream(stream))
diff --git a/python/semantic_kernel/connectors/ai/ollama/services/ollama_chat_completion.py b/python/semantic_kernel/connectors/ai/ollama/services/ollama_chat_completion.py
index 0c64849823ab..db35a54c58c4 100644
--- a/python/semantic_kernel/connectors/ai/ollama/services/ollama_chat_completion.py
+++ b/python/semantic_kernel/connectors/ai/ollama/services/ollama_chat_completion.py
@@ -12,6 +12,7 @@
from semantic_kernel.connectors.ai.ollama.ollama_prompt_execution_settings import OllamaChatPromptExecutionSettings
from semantic_kernel.connectors.ai.ollama.utils import AsyncSession
from semantic_kernel.connectors.ai.text_completion_client_base import TextCompletionClientBase
+from semantic_kernel.contents import AuthorRole
from semantic_kernel.contents.chat_history import ChatHistory
from semantic_kernel.contents.chat_message_content import ChatMessageContent
from semantic_kernel.contents.streaming_chat_message_content import StreamingChatMessageContent
@@ -66,7 +67,7 @@ async def get_chat_message_contents(
ChatMessageContent(
inner_content=response_object,
ai_model_id=self.ai_model_id,
- role="assistant",
+ role=AuthorRole.ASSISTANT,
content=response_object.get("message", {"content": None}).get("content", None),
)
]
@@ -105,7 +106,7 @@ async def get_streaming_chat_message_contents(
break
yield [
StreamingChatMessageContent(
- role="assistant",
+ role=AuthorRole.ASSISTANT,
choice_index=0,
inner_content=body,
ai_model_id=self.ai_model_id,
@@ -131,7 +132,7 @@ async def get_text_contents(
"""
if not settings.ai_model_id:
settings.ai_model_id = self.ai_model_id
- settings.messages = [{"role": "user", "content": prompt}]
+ settings.messages = [{"role": AuthorRole.USER, "content": prompt}]
settings.stream = False
async with (
AsyncSession(self.session) as session,
@@ -165,7 +166,7 @@ async def get_streaming_text_contents(
"""
if not settings.ai_model_id:
settings.ai_model_id = self.ai_model_id
- settings.messages = [{"role": "user", "content": prompt}]
+ settings.messages = [{"role": AuthorRole.USER, "content": prompt}]
settings.stream = True
async with (
AsyncSession(self.session) as session,
diff --git a/python/semantic_kernel/connectors/ai/open_ai/services/open_ai_chat_completion_base.py b/python/semantic_kernel/connectors/ai/open_ai/services/open_ai_chat_completion_base.py
index e6ebe70a780c..ab36599eb2df 100644
--- a/python/semantic_kernel/connectors/ai/open_ai/services/open_ai_chat_completion_base.py
+++ b/python/semantic_kernel/connectors/ai/open_ai/services/open_ai_chat_completion_base.py
@@ -249,12 +249,12 @@ async def get_streaming_chat_message_contents(
def _chat_message_content_to_dict(self, message: "ChatMessageContent") -> dict[str, str | None]:
msg = super()._chat_message_content_to_dict(message)
- if message.role == "assistant":
+ if message.role == AuthorRole.ASSISTANT:
if tool_calls := getattr(message, "tool_calls", None):
msg["tool_calls"] = [tool_call.model_dump() for tool_call in tool_calls]
if function_call := getattr(message, "function_call", None):
msg["function_call"] = function_call.model_dump_json()
- if message.role == "tool":
+ if message.role == AuthorRole.TOOL:
if tool_call_id := getattr(message, "tool_call_id", None):
msg["tool_call_id"] = tool_call_id
if message.metadata and "function" in message.metadata:
diff --git a/python/semantic_kernel/contents/function_result_content.py b/python/semantic_kernel/contents/function_result_content.py
index e9d28461ff72..be4a4402d783 100644
--- a/python/semantic_kernel/contents/function_result_content.py
+++ b/python/semantic_kernel/contents/function_result_content.py
@@ -6,6 +6,7 @@
from pydantic import field_validator
+from semantic_kernel.contents.author_role import AuthorRole
from semantic_kernel.contents.const import FUNCTION_RESULT_CONTENT_TAG, TEXT_CONTENT_TAG
from semantic_kernel.contents.kernel_content import KernelContent
from semantic_kernel.contents.text_content import TextContent
@@ -104,8 +105,8 @@ def to_chat_message_content(self, unwrap: bool = False) -> "ChatMessageContent":
from semantic_kernel.contents.chat_message_content import ChatMessageContent
if unwrap:
- return ChatMessageContent(role="tool", items=[self.result]) # type: ignore
- return ChatMessageContent(role="tool", items=[self]) # type: ignore
+ return ChatMessageContent(role=AuthorRole.TOOL, items=[self.result]) # type: ignore
+ return ChatMessageContent(role=AuthorRole.TOOL, items=[self]) # type: ignore
def to_dict(self) -> dict[str, str]:
"""Convert the instance to a dictionary."""
diff --git a/python/tests/unit/connectors/open_ai/services/test_open_ai_chat_completion_base.py b/python/tests/unit/connectors/open_ai/services/test_open_ai_chat_completion_base.py
index ab5d011e7b09..2e2cb8903502 100644
--- a/python/tests/unit/connectors/open_ai/services/test_open_ai_chat_completion_base.py
+++ b/python/tests/unit/connectors/open_ai/services/test_open_ai_chat_completion_base.py
@@ -10,7 +10,7 @@
OpenAIChatPromptExecutionSettings,
)
from semantic_kernel.connectors.ai.open_ai.services.open_ai_chat_completion import OpenAIChatCompletionBase
-from semantic_kernel.contents import ChatMessageContent, StreamingChatMessageContent, TextContent
+from semantic_kernel.contents import AuthorRole, ChatMessageContent, StreamingChatMessageContent, TextContent
from semantic_kernel.contents.chat_history import ChatHistory
from semantic_kernel.contents.function_call_content import FunctionCallContent
from semantic_kernel.exceptions import FunctionCallInvalidArgumentsException
@@ -64,7 +64,9 @@ async def test_complete_chat(tool_call, kernel: Kernel):
settings.function_call_behavior = None
mock_function_call = MagicMock(spec=FunctionCallContent)
mock_text = MagicMock(spec=TextContent)
- mock_message = ChatMessageContent(role="assistant", items=[mock_function_call] if tool_call else [mock_text])
+ mock_message = ChatMessageContent(
+ role=AuthorRole.ASSISTANT, items=[mock_function_call] if tool_call else [mock_text]
+ )
mock_message_content = [mock_message]
arguments = KernelArguments()
diff --git a/python/tests/unit/contents/test_chat_history.py b/python/tests/unit/contents/test_chat_history.py
index 33a8a1439712..89fdb1925b8e 100644
--- a/python/tests/unit/contents/test_chat_history.py
+++ b/python/tests/unit/contents/test_chat_history.py
@@ -213,9 +213,9 @@ def test_dump():
)
dump = chat_history.model_dump(exclude_none=True)
assert dump is not None
- assert dump["messages"][0]["role"] == "system"
+ assert dump["messages"][0]["role"] == AuthorRole.SYSTEM
assert dump["messages"][0]["items"][0]["text"] == system_msg
- assert dump["messages"][1]["role"] == "user"
+ assert dump["messages"][1]["role"] == AuthorRole.USER
assert dump["messages"][1]["items"][0]["text"] == "Message"
diff --git a/python/tests/unit/contents/test_chat_message_content.py b/python/tests/unit/contents/test_chat_message_content.py
index a2eeec17a9fb..38ee93d8e6bb 100644
--- a/python/tests/unit/contents/test_chat_message_content.py
+++ b/python/tests/unit/contents/test_chat_message_content.py
@@ -12,7 +12,7 @@
def test_cmc():
- message = ChatMessageContent(role="user", content="Hello, world!")
+ message = ChatMessageContent(role=AuthorRole.USER, content="Hello, world!")
assert message.role == AuthorRole.USER
assert message.content == "Hello, world!"
assert len(message.items) == 1
@@ -20,12 +20,13 @@ def test_cmc():
def test_cmc_str():
message = ChatMessageContent(role="user", content="Hello, world!")
+ assert message.role == AuthorRole.USER
assert str(message) == "Hello, world!"
def test_cmc_full():
message = ChatMessageContent(
- role="user",
+ role=AuthorRole.USER,
name="username",
content="Hello, world!",
inner_content="Hello, world!",
@@ -42,14 +43,14 @@ def test_cmc_full():
def test_cmc_items():
- message = ChatMessageContent(role="user", items=[TextContent(text="Hello, world!")])
+ message = ChatMessageContent(role=AuthorRole.USER, items=[TextContent(text="Hello, world!")])
assert message.role == AuthorRole.USER
assert message.content == "Hello, world!"
assert len(message.items) == 1
def test_cmc_items_and_content():
- message = ChatMessageContent(role="user", content="text", items=[TextContent(text="Hello, world!")])
+ message = ChatMessageContent(role=AuthorRole.USER, content="text", items=[TextContent(text="Hello, world!")])
assert message.role == AuthorRole.USER
assert message.content == "Hello, world!"
assert message.items[0].text == "Hello, world!"
@@ -59,7 +60,7 @@ def test_cmc_items_and_content():
def test_cmc_multiple_items():
message = ChatMessageContent(
- role="system",
+ role=AuthorRole.SYSTEM,
items=[
TextContent(text="Hello, world!"),
TextContent(text="Hello, world!"),
@@ -71,7 +72,7 @@ def test_cmc_multiple_items():
def test_cmc_content_set():
- message = ChatMessageContent(role="user", content="Hello, world!")
+ message = ChatMessageContent(role=AuthorRole.USER, content="Hello, world!")
assert message.role == AuthorRole.USER
assert message.content == "Hello, world!"
message.content = "Hello, world to you too!"
@@ -82,7 +83,7 @@ def test_cmc_content_set():
def test_cmc_content_set_empty():
- message = ChatMessageContent(role="user", content="Hello, world!")
+ message = ChatMessageContent(role=AuthorRole.USER, content="Hello, world!")
assert message.role == AuthorRole.USER
assert message.content == "Hello, world!"
message.items.pop()
@@ -92,7 +93,7 @@ def test_cmc_content_set_empty():
def test_cmc_to_element():
- message = ChatMessageContent(role="user", content="Hello, world!", name=None)
+ message = ChatMessageContent(role=AuthorRole.USER, content="Hello, world!", name=None)
element = message.to_element()
assert element.tag == "message"
assert element.attrib == {"role": "user"}
@@ -103,13 +104,13 @@ def test_cmc_to_element():
def test_cmc_to_prompt():
- message = ChatMessageContent(role="user", content="Hello, world!")
+ message = ChatMessageContent(role=AuthorRole.USER, content="Hello, world!")
prompt = message.to_prompt()
assert prompt == 'Hello, world!'
def test_cmc_from_element():
- element = ChatMessageContent(role="user", content="Hello, world!").to_element()
+ element = ChatMessageContent(role=AuthorRole.USER, content="Hello, world!").to_element()
message = ChatMessageContent.from_element(element)
assert message.role == AuthorRole.USER
assert message.content == "Hello, world!"
@@ -182,14 +183,14 @@ def test_cmc_from_element_content_parse(xml_content, user, text_content, length)
def test_cmc_serialize():
- message = ChatMessageContent(role="user", content="Hello, world!")
+ message = ChatMessageContent(role=AuthorRole.USER, content="Hello, world!")
dumped = message.model_dump()
- assert dumped["role"] == "user"
+ assert dumped["role"] == AuthorRole.USER
assert dumped["items"][0]["text"] == "Hello, world!"
def test_cmc_to_dict():
- message = ChatMessageContent(role="user", content="Hello, world!")
+ message = ChatMessageContent(role=AuthorRole.USER, content="Hello, world!")
assert message.to_dict() == {
"role": "user",
"content": "Hello, world!",
@@ -197,7 +198,7 @@ def test_cmc_to_dict():
def test_cmc_to_dict_keys():
- message = ChatMessageContent(role="user", content="Hello, world!")
+ message = ChatMessageContent(role=AuthorRole.USER, content="Hello, world!")
assert message.to_dict(role_key="author", content_key="text") == {
"author": "user",
"text": "Hello, world!",
diff --git a/python/tests/unit/contents/test_streaming_chat_message_content.py b/python/tests/unit/contents/test_streaming_chat_message_content.py
index a6d13430a37a..f09f6c3408be 100644
--- a/python/tests/unit/contents/test_streaming_chat_message_content.py
+++ b/python/tests/unit/contents/test_streaming_chat_message_content.py
@@ -15,7 +15,7 @@
def test_scmc():
- message = StreamingChatMessageContent(choice_index=0, role="user", content="Hello, world!")
+ message = StreamingChatMessageContent(choice_index=0, role=AuthorRole.USER, content="Hello, world!")
assert message.role == AuthorRole.USER
assert message.content == "Hello, world!"
assert len(message.items) == 1
@@ -29,7 +29,7 @@ def test_scmc_str():
def test_scmc_full():
message = StreamingChatMessageContent(
choice_index=0,
- role="user",
+ role=AuthorRole.USER,
name="username",
content="Hello, world!",
inner_content="Hello, world!",
@@ -46,7 +46,9 @@ def test_scmc_full():
def test_scmc_items():
- message = StreamingChatMessageContent(choice_index=0, role="user", items=[TextContent(text="Hello, world!")])
+ message = StreamingChatMessageContent(
+ choice_index=0, role=AuthorRole.USER, items=[TextContent(text="Hello, world!")]
+ )
assert message.role == AuthorRole.USER
assert message.content == "Hello, world!"
assert len(message.items) == 1
@@ -54,7 +56,7 @@ def test_scmc_items():
def test_scmc_items_and_content():
message = StreamingChatMessageContent(
- choice_index=0, role="user", content="text", items=[TextContent(text="Hello, world!")]
+ choice_index=0, role=AuthorRole.USER, content="text", items=[TextContent(text="Hello, world!")]
)
assert message.role == AuthorRole.USER
assert message.content == "Hello, world!"
@@ -66,7 +68,7 @@ def test_scmc_items_and_content():
def test_scmc_multiple_items():
message = StreamingChatMessageContent(
choice_index=0,
- role="system",
+ role=AuthorRole.SYSTEM,
items=[
TextContent(text="Hello, world!"),
TextContent(text="Hello, world!"),
@@ -78,7 +80,7 @@ def test_scmc_multiple_items():
def test_scmc_content_set():
- message = StreamingChatMessageContent(choice_index=0, role="user", content="Hello, world!")
+ message = StreamingChatMessageContent(choice_index=0, role=AuthorRole.USER, content="Hello, world!")
assert message.role == AuthorRole.USER
assert message.content == "Hello, world!"
message.content = "Hello, world to you too!"
@@ -89,7 +91,7 @@ def test_scmc_content_set():
def test_scmc_content_set_empty():
- message = StreamingChatMessageContent(choice_index=0, role="user", content="Hello, world!")
+ message = StreamingChatMessageContent(choice_index=0, role=AuthorRole.USER, content="Hello, world!")
assert message.role == AuthorRole.USER
assert message.content == "Hello, world!"
message.items.pop()
@@ -99,7 +101,7 @@ def test_scmc_content_set_empty():
def test_scmc_to_element():
- message = StreamingChatMessageContent(choice_index=0, role="user", content="Hello, world!", name=None)
+ message = StreamingChatMessageContent(choice_index=0, role=AuthorRole.USER, content="Hello, world!", name=None)
element = message.to_element()
assert element.tag == "message"
assert element.attrib == {"role": "user", "choice_index": "0"}
@@ -110,7 +112,7 @@ def test_scmc_to_element():
def test_scmc_to_prompt():
- message = StreamingChatMessageContent(choice_index=0, role="user", content="Hello, world!")
+ message = StreamingChatMessageContent(choice_index=0, role=AuthorRole.USER, content="Hello, world!")
prompt = message.to_prompt()
assert "Hello, world!" in prompt
assert 'choice_index="0"' in prompt
@@ -118,7 +120,7 @@ def test_scmc_to_prompt():
def test_scmc_from_element():
- element = StreamingChatMessageContent(choice_index=0, role="user", content="Hello, world!").to_element()
+ element = StreamingChatMessageContent(choice_index=0, role=AuthorRole.USER, content="Hello, world!").to_element()
message = StreamingChatMessageContent.from_element(element)
assert message.role == AuthorRole.USER
assert message.content == "Hello, world!"
@@ -188,14 +190,14 @@ def test_scmc_from_element_content_parse(xml_content, user, text_content, length
def test_scmc_serialize():
- message = StreamingChatMessageContent(choice_index=0, role="user", content="Hello, world!")
+ message = StreamingChatMessageContent(choice_index=0, role=AuthorRole.USER, content="Hello, world!")
dumped = message.model_dump()
- assert dumped["role"] == "user"
+ assert dumped["role"] == AuthorRole.USER
assert dumped["items"][0]["text"] == "Hello, world!"
def test_scmc_to_dict():
- message = StreamingChatMessageContent(choice_index=0, role="user", content="Hello, world!")
+ message = StreamingChatMessageContent(choice_index=0, role=AuthorRole.USER, content="Hello, world!")
assert message.to_dict() == {
"role": "user",
"content": "Hello, world!",
@@ -203,7 +205,7 @@ def test_scmc_to_dict():
def test_scmc_to_dict_keys():
- message = StreamingChatMessageContent(choice_index=0, role="user", content="Hello, world!")
+ message = StreamingChatMessageContent(choice_index=0, role=AuthorRole.USER, content="Hello, world!")
assert message.to_dict(role_key="author", content_key="text") == {
"author": "user",
"text": "Hello, world!",
@@ -254,8 +256,12 @@ def test_scmc_to_dict_items(input_args, expected_dict):
def test_scmc_add():
- message1 = StreamingChatMessageContent(choice_index=0, role="user", content="Hello, ", inner_content="source1")
- message2 = StreamingChatMessageContent(choice_index=0, role="user", content="world!", inner_content="source2")
+ message1 = StreamingChatMessageContent(
+ choice_index=0, role=AuthorRole.USER, content="Hello, ", inner_content="source1"
+ )
+ message2 = StreamingChatMessageContent(
+ choice_index=0, role=AuthorRole.USER, content="world!", inner_content="source2"
+ )
combined = message1 + message2
assert combined.role == AuthorRole.USER
assert combined.content == "Hello, world!"
@@ -264,9 +270,13 @@ def test_scmc_add():
def test_scmc_add_three():
- message1 = StreamingChatMessageContent(choice_index=0, role="user", content="Hello, ", inner_content="source1")
- message2 = StreamingChatMessageContent(choice_index=0, role="user", content="world", inner_content="source2")
- message3 = StreamingChatMessageContent(choice_index=0, role="user", content="!", inner_content="source3")
+ message1 = StreamingChatMessageContent(
+ choice_index=0, role=AuthorRole.USER, content="Hello, ", inner_content="source1"
+ )
+ message2 = StreamingChatMessageContent(
+ choice_index=0, role=AuthorRole.USER, content="world", inner_content="source2"
+ )
+ message3 = StreamingChatMessageContent(choice_index=0, role=AuthorRole.USER, content="!", inner_content="source3")
combined = message1 + message2 + message3
assert combined.role == AuthorRole.USER
assert combined.content == "Hello, world!"
@@ -277,13 +287,13 @@ def test_scmc_add_three():
def test_scmc_add_different_items():
message1 = StreamingChatMessageContent(
choice_index=0,
- role="user",
+ role=AuthorRole.USER,
items=[StreamingTextContent(choice_index=0, text="Hello, ")],
inner_content="source1",
)
message2 = StreamingChatMessageContent(
choice_index=0,
- role="user",
+ role=AuthorRole.USER,
items=[FunctionResultContent(id="test", name="test", result="test")],
inner_content="source2",
)
@@ -298,24 +308,24 @@ def test_scmc_add_different_items():
"message1, message2",
[
(
- StreamingChatMessageContent(choice_index=0, role="user", content="Hello, "),
- StreamingChatMessageContent(choice_index=0, role="assistant", content="world!"),
+ StreamingChatMessageContent(choice_index=0, role=AuthorRole.USER, content="Hello, "),
+ StreamingChatMessageContent(choice_index=0, role=AuthorRole.ASSISTANT, content="world!"),
),
(
- StreamingChatMessageContent(choice_index=0, role="user", content="Hello, "),
- StreamingChatMessageContent(choice_index=1, role="user", content="world!"),
+ StreamingChatMessageContent(choice_index=0, role=AuthorRole.USER, content="Hello, "),
+ StreamingChatMessageContent(choice_index=1, role=AuthorRole.USER, content="world!"),
),
(
- StreamingChatMessageContent(choice_index=0, role="user", content="Hello, ", ai_model_id="1234"),
- StreamingChatMessageContent(choice_index=0, role="user", content="world!", ai_model_id="5678"),
+ StreamingChatMessageContent(choice_index=0, role=AuthorRole.USER, content="Hello, ", ai_model_id="1234"),
+ StreamingChatMessageContent(choice_index=0, role=AuthorRole.USER, content="world!", ai_model_id="5678"),
),
(
- StreamingChatMessageContent(choice_index=0, role="user", content="Hello, ", encoding="utf-8"),
- StreamingChatMessageContent(choice_index=0, role="user", content="world!", encoding="utf-16"),
+ StreamingChatMessageContent(choice_index=0, role=AuthorRole.USER, content="Hello, ", encoding="utf-8"),
+ StreamingChatMessageContent(choice_index=0, role=AuthorRole.USER, content="world!", encoding="utf-16"),
),
(
- StreamingChatMessageContent(choice_index=0, role="user", content="Hello, "),
- ChatMessageContent(role="user", content="world!"),
+ StreamingChatMessageContent(choice_index=0, role=AuthorRole.USER, content="Hello, "),
+ ChatMessageContent(role=AuthorRole.USER, content="world!"),
),
],
ids=["different_roles", "different_index", "different_model", "different_encoding", "different_type"],
@@ -326,5 +336,5 @@ def test_smsc_add_exception(message1, message2):
def test_scmc_bytes():
- message = StreamingChatMessageContent(choice_index=0, role="user", content="Hello, world!")
+ message = StreamingChatMessageContent(choice_index=0, role=AuthorRole.USER, content="Hello, world!")
assert bytes(message) == b"Hello, world!"
diff --git a/python/tests/unit/functions/test_kernel_function_from_prompt.py b/python/tests/unit/functions/test_kernel_function_from_prompt.py
index 293ea5e28741..21abc647a0df 100644
--- a/python/tests/unit/functions/test_kernel_function_from_prompt.py
+++ b/python/tests/unit/functions/test_kernel_function_from_prompt.py
@@ -9,6 +9,7 @@
from semantic_kernel.connectors.ai.open_ai.services.open_ai_text_completion import OpenAITextCompletion
from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings
from semantic_kernel.const import METADATA_EXCEPTION_KEY
+from semantic_kernel.contents import AuthorRole
from semantic_kernel.contents.chat_message_content import ChatMessageContent
from semantic_kernel.contents.streaming_chat_message_content import StreamingChatMessageContent
from semantic_kernel.contents.text_content import TextContent
@@ -161,14 +162,16 @@ async def test_invoke_chat_stream(openai_unit_test_env):
with patch(
"semantic_kernel.connectors.ai.open_ai.services.open_ai_chat_completion.OpenAIChatCompletion.get_chat_message_contents"
) as mock:
- mock.return_value = [ChatMessageContent(role="assistant", content="test", metadata={})]
+ mock.return_value = [ChatMessageContent(role=AuthorRole.ASSISTANT, content="test", metadata={})]
result = await function.invoke(kernel=kernel)
assert str(result) == "test"
with patch(
"semantic_kernel.connectors.ai.open_ai.services.open_ai_chat_completion.OpenAIChatCompletion.get_streaming_chat_message_contents"
) as mock:
- mock.return_value = [StreamingChatMessageContent(choice_index=0, role="assistant", content="test", metadata={})]
+ mock.return_value = [
+ StreamingChatMessageContent(choice_index=0, role=AuthorRole.ASSISTANT, content="test", metadata={})
+ ]
async for result in function.invoke_stream(kernel=kernel):
assert str(result) == "test"
@@ -187,7 +190,7 @@ async def test_invoke_exception(openai_unit_test_env):
"semantic_kernel.connectors.ai.open_ai.services.open_ai_chat_completion.OpenAIChatCompletion.get_chat_message_contents",
side_effect=Exception,
) as mock:
- mock.return_value = [ChatMessageContent(role="assistant", content="test", metadata={})]
+ mock.return_value = [ChatMessageContent(role=AuthorRole.ASSISTANT, content="test", metadata={})]
with pytest.raises(Exception, match="test"):
await function.invoke(kernel=kernel)
@@ -195,7 +198,9 @@ async def test_invoke_exception(openai_unit_test_env):
"semantic_kernel.connectors.ai.open_ai.services.open_ai_chat_completion.OpenAIChatCompletion.get_streaming_chat_message_contents",
side_effect=Exception,
) as mock:
- mock.return_value = [StreamingChatMessageContent(choice_index=0, role="assistant", content="test", metadata={})]
+ mock.return_value = [
+ StreamingChatMessageContent(choice_index=0, role=AuthorRole.ASSISTANT, content="test", metadata={})
+ ]
with pytest.raises(Exception):
async for result in function.invoke_stream(kernel=kernel):
assert isinstance(result.metadata[METADATA_EXCEPTION_KEY], Exception)
@@ -270,7 +275,7 @@ async def test_invoke_defaults(openai_unit_test_env):
with patch(
"semantic_kernel.connectors.ai.open_ai.services.open_ai_chat_completion.OpenAIChatCompletion.get_chat_message_contents"
) as mock:
- mock.return_value = [ChatMessageContent(role="assistant", content="test", metadata={})]
+ mock.return_value = [ChatMessageContent(role=AuthorRole.ASSISTANT, content="test", metadata={})]
result = await function.invoke(kernel=kernel)
assert str(result) == "test"
@@ -313,7 +318,7 @@ async def test_create_with_multiple_settings_one_service_registered(openai_unit_
with patch(
"semantic_kernel.connectors.ai.open_ai.services.open_ai_chat_completion.OpenAIChatCompletion.get_chat_message_contents"
) as mock:
- mock.return_value = [ChatMessageContent(role="assistant", content="test", metadata={})]
+ mock.return_value = [ChatMessageContent(role=AuthorRole.ASSISTANT, content="test", metadata={})]
result = await function.invoke(kernel=kernel)
assert str(result) == "test"
diff --git a/python/tests/unit/planners/function_calling_stepwise_planner/test_function_calling_stepwise_planner.py b/python/tests/unit/planners/function_calling_stepwise_planner/test_function_calling_stepwise_planner.py
index 8092815094b5..0f89df2a62eb 100644
--- a/python/tests/unit/planners/function_calling_stepwise_planner/test_function_calling_stepwise_planner.py
+++ b/python/tests/unit/planners/function_calling_stepwise_planner/test_function_calling_stepwise_planner.py
@@ -5,6 +5,7 @@
import pytest
from semantic_kernel.connectors.ai.open_ai.services.open_ai_chat_completion import OpenAIChatCompletion
+from semantic_kernel.contents import AuthorRole
from semantic_kernel.exceptions.planner_exceptions import PlannerInvalidConfigurationError
from semantic_kernel.functions.kernel_arguments import KernelArguments
from semantic_kernel.functions.kernel_function import KernelFunction
@@ -60,7 +61,7 @@ async def test_build_chat_history_for_step():
"goal", "initial_plan", kernel_mock, arguments_mock, service_mock
)
assert chat_history is not None
- assert chat_history[0].role == "user"
+ assert chat_history[0].role == AuthorRole.USER
@pytest.mark.asyncio
diff --git a/python/tests/unit/prompt_template/test_handlebars_prompt_template.py b/python/tests/unit/prompt_template/test_handlebars_prompt_template.py
index 387dd8458be8..13468f6d35a0 100644
--- a/python/tests/unit/prompt_template/test_handlebars_prompt_template.py
+++ b/python/tests/unit/prompt_template/test_handlebars_prompt_template.py
@@ -3,6 +3,7 @@
import pytest
from pytest import mark
+from semantic_kernel.contents import AuthorRole
from semantic_kernel.contents.chat_history import ChatHistory
from semantic_kernel.contents.chat_message_content import ChatMessageContent
from semantic_kernel.contents.function_call_content import FunctionCallContent
@@ -285,10 +286,12 @@ async def test_helpers_message_to_prompt(kernel: Kernel):
chat_history = ChatHistory()
chat_history.add_user_message("User message")
chat_history.add_message(
- ChatMessageContent(role="assistant", items=[FunctionCallContent(id="1", name="plug-test")])
+ ChatMessageContent(role=AuthorRole.ASSISTANT, items=[FunctionCallContent(id="1", name="plug-test")])
)
chat_history.add_message(
- ChatMessageContent(role="tool", items=[FunctionResultContent(id="1", name="plug-test", result="Tool message")])
+ ChatMessageContent(
+ role=AuthorRole.TOOL, items=[FunctionResultContent(id="1", name="plug-test", result="Tool message")]
+ )
)
rendered = await target.render(kernel, KernelArguments(chat_history=chat_history))
diff --git a/python/tests/unit/prompt_template/test_jinja2_prompt_template.py b/python/tests/unit/prompt_template/test_jinja2_prompt_template.py
index 33ed7ffdf5db..cb5a3a34e86e 100644
--- a/python/tests/unit/prompt_template/test_jinja2_prompt_template.py
+++ b/python/tests/unit/prompt_template/test_jinja2_prompt_template.py
@@ -3,6 +3,7 @@
import pytest
from pytest import mark
+from semantic_kernel.contents import AuthorRole
from semantic_kernel.contents.chat_history import ChatHistory
from semantic_kernel.contents.chat_message_content import ChatMessageContent
from semantic_kernel.contents.function_call_content import FunctionCallContent
@@ -293,10 +294,12 @@ async def test_helpers_message_to_prompt(kernel: Kernel):
chat_history = ChatHistory()
chat_history.add_user_message("User message")
chat_history.add_message(
- ChatMessageContent(role="assistant", items=[FunctionCallContent(id="1", name="plug-test")])
+ ChatMessageContent(role=AuthorRole.ASSISTANT, items=[FunctionCallContent(id="1", name="plug-test")])
)
chat_history.add_message(
- ChatMessageContent(role="tool", items=[FunctionResultContent(id="1", name="plug-test", result="Tool message")])
+ ChatMessageContent(
+ role=AuthorRole.TOOL, items=[FunctionResultContent(id="1", name="plug-test", result="Tool message")]
+ )
)
rendered = await target.render(kernel, KernelArguments(chat_history=chat_history))
diff --git a/python/tests/unit/prompt_template/test_prompt_template_e2e.py b/python/tests/unit/prompt_template/test_prompt_template_e2e.py
index 75000c462686..a145b29cbc04 100644
--- a/python/tests/unit/prompt_template/test_prompt_template_e2e.py
+++ b/python/tests/unit/prompt_template/test_prompt_template_e2e.py
@@ -5,6 +5,7 @@
from pytest import mark, raises
from semantic_kernel import Kernel
+from semantic_kernel.contents import AuthorRole
from semantic_kernel.contents.chat_history import ChatHistory
from semantic_kernel.exceptions import TemplateSyntaxError
from semantic_kernel.functions import kernel_function
@@ -360,13 +361,13 @@ async def test_renders_and_can_be_parsed(kernel: Kernel):
).render(kernel, KernelArguments(unsafe_input=unsafe_input, safe_input=safe_input))
chat_history = ChatHistory.from_rendered_prompt(result)
assert chat_history
- assert chat_history.messages[0].role == "system"
+ assert chat_history.messages[0].role == AuthorRole.SYSTEM
assert chat_history.messages[0].content == "This is the system message"
- assert chat_history.messages[1].role == "user"
+ assert chat_history.messages[1].role == AuthorRole.USER
assert chat_history.messages[1].content == "This is the newer system message"
- assert chat_history.messages[2].role == "user"
+ assert chat_history.messages[2].role == AuthorRole.USER
assert chat_history.messages[2].content == "This is bold text"
- assert chat_history.messages[3].role == "user"
+ assert chat_history.messages[3].role == AuthorRole.USER
assert chat_history.messages[3].content == "This is the newest system message"
@@ -397,11 +398,11 @@ async def test_renders_and_can_be_parsed_with_cdata_sections(kernel: Kernel):
)
chat_history = ChatHistory.from_rendered_prompt(result)
assert chat_history
- assert chat_history.messages[0].role == "user"
+ assert chat_history.messages[0].role == AuthorRole.USER
assert chat_history.messages[0].content == "This is the newer system message"
- assert chat_history.messages[1].role == "user"
+ assert chat_history.messages[1].role == AuthorRole.USER
assert chat_history.messages[1].content == "explain imagehttps://fake-link-to-image/"
- assert chat_history.messages[2].role == "user"
+ assert chat_history.messages[2].role == AuthorRole.USER
assert (
chat_history.messages[2].content
== "]]>This is the newer system message