Skip to content

Python: Update chat history channel to produce the correct messages with a visibility bool #7619

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Aug 5, 2024
7 changes: 3 additions & 4 deletions python/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions python/semantic_kernel/agents/agent_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,14 @@ async def receive(
def invoke(
self,
agent: "Agent",
) -> AsyncIterable["ChatMessageContent"]:
) -> AsyncIterable[tuple[bool, "ChatMessageContent"]]:
"""Perform a discrete incremental interaction between a single Agent and AgentChat.

Args:
agent: The agent to interact with.

Returns:
An async iterable of ChatMessageContent.
A async iterable of a bool, ChatMessageContent.
"""
...

Expand Down
50 changes: 45 additions & 5 deletions python/semantic_kernel/agents/chat_history_channel.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft. All rights reserved.

import sys
from collections import deque
from collections.abc import AsyncIterable

if sys.version_info >= (3, 12):
Expand All @@ -9,12 +10,14 @@
from typing_extensions import override # pragma: no cover

from abc import abstractmethod
from typing import TYPE_CHECKING, Protocol, runtime_checkable
from typing import TYPE_CHECKING, Deque, Protocol, runtime_checkable

from semantic_kernel.agents.agent import Agent
from semantic_kernel.agents.agent_channel import AgentChannel
from semantic_kernel.contents import ChatMessageContent
from semantic_kernel.contents.chat_history import ChatHistory
from semantic_kernel.contents.function_call_content import FunctionCallContent
from semantic_kernel.contents.function_result_content import FunctionResultContent
from semantic_kernel.exceptions import ServiceInvalidTypeError
from semantic_kernel.utils.experimental_decorator import experimental_class

Expand Down Expand Up @@ -48,7 +51,7 @@ class ChatHistoryChannel(AgentChannel, ChatHistory):
async def invoke(
self,
agent: Agent,
) -> AsyncIterable[ChatMessageContent]:
) -> AsyncIterable[tuple[bool, ChatMessageContent]]:
"""Perform a discrete incremental interaction between a single Agent and AgentChat.

Args:
Expand All @@ -63,9 +66,46 @@ async def invoke(
f"Invalid channel binding for agent with id: `{id}` with name: ({type(agent).__name__})"
)

async for message in agent.invoke(self):
self.messages.append(message)
yield message
message_count = len(self.messages)
mutated_history = set()
message_queue: Deque[ChatMessageContent] = deque()

async for response_message in agent.invoke(self):
# Capture all messages that have been included in the mutated history.
for message_index in range(message_count, len(self.messages)):
mutated_message = self.messages[message_index]
mutated_history.add(mutated_message)
message_queue.append(mutated_message)

# Update the message count pointer to reflect the current history.
message_count = len(self.messages)

# Avoid duplicating any message included in the mutated history and also returned by the enumeration result.
if response_message not in mutated_history:
self.messages.append(response_message)
message_queue.append(response_message)

# Dequeue the next message to yield.
yield_message = message_queue.popleft()
yield (
self._is_message_visible(message=yield_message, message_queue_count=len(message_queue)),
yield_message,
)

# Dequeue any remaining messages to yield.
while message_queue:
yield_message = message_queue.popleft()
yield (
self._is_message_visible(message=yield_message, message_queue_count=len(message_queue)),
yield_message,
)

def _is_message_visible(self, message: ChatMessageContent, message_queue_count: int) -> bool:
"""Determine if a message is visible to the user."""
return (
not any(isinstance(item, (FunctionCallContent, FunctionResultContent)) for item in message.items)
or message_queue_count == 0
)

@override
async def receive(
Expand Down
4 changes: 4 additions & 0 deletions python/semantic_kernel/contents/chat_message_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,3 +306,7 @@ def _parse_items(self) -> str | list[dict[str, Any]]:
if len(self.items) == 1 and isinstance(self.items[0], FunctionResultContent):
return str(self.items[0].result)
return [item.to_dict() for item in self.items]

def __hash__(self) -> int:
"""Return the hash of the chat message content."""
return hash((self.tag, self.role, self.content, self.encoding, self.finish_reason, *self.items))
4 changes: 4 additions & 0 deletions python/semantic_kernel/contents/function_call_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,3 +195,7 @@ def to_dict(self) -> dict[str, str | Any]:
"""Convert the instance to a dictionary."""
args = json.dumps(self.arguments) if isinstance(self.arguments, dict) else self.arguments
return {"id": self.id, "type": "function", "function": {"name": self.name, "arguments": args}}

def __hash__(self) -> int:
"""Return the hash of the function call content."""
return hash((self.tag, self.id, self.index, self.name, self.function_name, self.plugin_name, self.arguments))
4 changes: 4 additions & 0 deletions python/semantic_kernel/contents/function_result_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,3 +172,7 @@ def split_name(self) -> list[str]:
def serialize_result(self, value: Any) -> str:
"""Serialize the result."""
return str(value)

def __hash__(self) -> int:
"""Return the hash of the function result content."""
return hash((self.tag, self.id, self.result, self.name, self.function_name, self.plugin_name, self.encoding))
4 changes: 4 additions & 0 deletions python/semantic_kernel/contents/text_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,7 @@ def from_element(cls: type[_T], element: Element) -> _T:
def to_dict(self) -> dict[str, str]:
"""Convert the instance to a dictionary."""
return {"type": "text", "text": self.text}

def __hash__(self) -> int:
"""Return the hash of the text content."""
return hash((self.tag, self.text, self.encoding))
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,7 @@ def services() -> dict[str, tuple[ChatCompletionClientBase | None, type[PromptEx
ChatMessageContent(role=AuthorRole.USER, items=[TextContent(text="How are you today?")]),
],
["Hello", "well"],
marks=pytest.mark.skip(reason="Skipping due to 429s from Google AI."),
id="google_ai_text_input",
),
pytest.param(
Expand Down Expand Up @@ -551,6 +552,7 @@ def services() -> dict[str, tuple[ChatCompletionClientBase | None, type[PromptEx
],
],
["1.2"],
marks=pytest.mark.skip(reason="Skipping due to 429s from Google AI."),
id="google_ai_tool_call_flow",
),
pytest.param(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def services() -> dict[str, tuple[ChatCompletionClientBase, type[PromptExecution
{},
["Repeat the word Hello"],
["Hello"],
marks=pytest.mark.skip(reason="Skipping due to 429s from Google AI."),
id="google_ai_text_input",
),
pytest.param(
Expand Down
68 changes: 61 additions & 7 deletions python/tests/unit/agents/test_chat_history_channel.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
# Copyright (c) Microsoft. All rights reserved.

from collections.abc import AsyncIterable
from unittest.mock import AsyncMock

import pytest

from semantic_kernel.agents.chat_history_channel import ChatHistoryAgentProtocol, ChatHistoryChannel
from semantic_kernel.contents.chat_message_content import ChatMessageContent
from semantic_kernel.contents.streaming_chat_message_content import StreamingChatMessageContent
from semantic_kernel.contents.function_result_content import FunctionResultContent
from semantic_kernel.contents.utils.author_role import AuthorRole
from semantic_kernel.exceptions import ServiceInvalidTypeError

Expand All @@ -18,9 +19,6 @@ async def invoke(self, history: list[ChatMessageContent]) -> AsyncIterable[ChatM
for message in history:
yield ChatMessageContent(role=AuthorRole.SYSTEM, content=f"Processed: {message.content}")

async def invoke_stream(self, history: list[ChatMessageContent]) -> AsyncIterable["StreamingChatMessageContent"]:
pass


class MockNonChatHistoryHandler:
"""Mock agent to test incorrect instance handling."""
Expand All @@ -31,23 +29,79 @@ class MockNonChatHistoryHandler:
ChatHistoryAgentProtocol.register(MockChatHistoryHandler)


class AsyncIterableMock:
def __init__(self, async_gen):
self.async_gen = async_gen

def __aiter__(self):
return self.async_gen()


@pytest.mark.asyncio
async def test_invoke():
channel = ChatHistoryChannel()
agent = MockChatHistoryHandler()
agent = AsyncMock(spec=MockChatHistoryHandler)

async def mock_invoke(history: list[ChatMessageContent]):
for message in history:
yield ChatMessageContent(role=AuthorRole.SYSTEM, content=f"Processed: {message.content}")

agent.invoke.return_value = AsyncIterableMock(
lambda: mock_invoke([ChatMessageContent(role=AuthorRole.USER, content="Initial message")])
)

initial_message = ChatMessageContent(role=AuthorRole.USER, content="Initial message")
channel.messages.append(initial_message)

received_messages = []
async for message in channel.invoke(agent):
async for is_visible, message in channel.invoke(agent):
received_messages.append(message)
break # only process one message for the test
assert is_visible

assert len(received_messages) == 1
assert "Processed: Initial message" in received_messages[0].content


@pytest.mark.asyncio
async def test_invoke_leftover_in_queue():
channel = ChatHistoryChannel()
agent = AsyncMock(spec=MockChatHistoryHandler)

async def mock_invoke(history: list[ChatMessageContent]):
for message in history:
yield ChatMessageContent(role=AuthorRole.SYSTEM, content=f"Processed: {message.content}")
yield ChatMessageContent(
role=AuthorRole.SYSTEM, content="Final message", items=[FunctionResultContent(id="test_id", result="test")]
)

agent.invoke.return_value = AsyncIterableMock(
lambda: mock_invoke(
[
ChatMessageContent(
role=AuthorRole.USER,
content="Initial message",
items=[FunctionResultContent(id="test_id", result="test")],
)
]
)
)

initial_message = ChatMessageContent(role=AuthorRole.USER, content="Initial message")
channel.messages.append(initial_message)

received_messages = []
async for is_visible, message in channel.invoke(agent):
received_messages.append(message)
assert is_visible
if len(received_messages) >= 3:
break

assert len(received_messages) == 3
assert "Processed: Initial message" in received_messages[0].content
assert "Final message" in received_messages[2].content
assert received_messages[2].items[0].id == "test_id"


@pytest.mark.asyncio
async def test_invoke_incorrect_instance_throws():
channel = ChatHistoryChannel()
Expand Down
3 changes: 3 additions & 0 deletions python/tests/unit/agents/test_open_ai_assistant_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ def mock_run_failed():
object="thread.run",
thread_id="thread_id",
tools=[],
parallel_tool_calls=True,
)


Expand Down Expand Up @@ -214,6 +215,7 @@ def mock_run_required_action():
]
),
),
parallel_tool_calls=True,
)


Expand All @@ -239,6 +241,7 @@ def mock_run_completed():
]
),
),
parallel_tool_calls=True,
)


Expand Down
Loading