Add Automatic Context Summarization to ClientSession #1175

Closed
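
This change adds a ClientSessionSummarizing class that extends ClientSession with automatic context summarization: sampling messages are accumulated in a local history, token usage is estimated with tiktoken, and once the estimate crosses max_tokens * summarize_threshold the history is collapsed into a single summary message. The sketch below shows how the class might be wired up over stdio; it is illustrative only, assumes the stdio transport helpers exported by the mcp package (stdio_client, StdioServerParameters), and uses a hypothetical my_server.py command.

# Illustrative usage sketch -- not part of this PR. Assumes the SDK's stdio
# transport helpers; the server command below is hypothetical.
import anyio

from mcp import StdioServerParameters
from mcp.client.stdio import stdio_client
from mcp.client.client_session_summarizing import ClientSessionSummarizing


async def main() -> None:
    server = StdioServerParameters(command="python", args=["my_server.py"])
    async with stdio_client(server) as (read_stream, write_stream):
        async with ClientSessionSummarizing(
            read_stream=read_stream,
            write_stream=write_stream,
            max_tokens=2000,          # tighter budget than the 4000-token default
            summarize_threshold=0.5,  # summarize once half the budget is used
        ) as session:
            await session.initialize()
            # Server-initiated sampling requests are now handled by the
            # summarizing callback: each one is appended to session.history,
            # and the history is condensed when the token estimate exceeds
            # 2000 * 0.5 = 1000 tokens.


anyio.run(main)
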
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -33,6 +33,7 @@ dependencies = [
"uvicorn>=0.23.1; sys_platform != 'emscripten'",
"jsonschema>=4.20.0",
"pywin32>=310; sys_platform == 'win32'",
"tiktoken>=0.9.0",
]

[project.optional-dependencies]
@@ -59,6 +60,7 @@ dev = [
"pytest-pretty>=1.2.0",
"inline-snapshot>=0.23.0",
"dirty-equals>=0.9.0",
"pytest-asyncio>=1.1.0",
]
docs = [
"mkdocs>=1.6.1",
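
The new tiktoken runtime dependency backs the token counting in ClientSessionSummarizing.token_count(), and pytest-asyncio is added to the dev group for the asyncio-marked tests below. A quick sketch of the tiktoken call the class relies on (the encoding name comes from the implementation; exact token counts vary with the input text):

import tiktoken

# token_count() estimates history size with the cl100k_base encoding.
encoding = tiktoken.get_encoding("cl100k_base")
tokens = encoding.encode("Summarize the following conversation succinctly, preserving key facts:")
print(len(tokens))  # a small number of tokens; the count depends on the text
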
103 changes: 103 additions & 0 deletions src/mcp/client/client_session_summarizing.py
@@ -0,0 +1,103 @@
from datetime import timedelta
from typing import Any

import tiktoken
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream

from mcp.client.session import ClientSession
from mcp.shared.context import RequestContext
from mcp.shared.message import SessionMessage
from mcp.types import CreateMessageRequestParams, CreateMessageResult, SamplingMessage, TextContent

DEFAULT_MAX_TOKENS = 4000
DEFAULT_SUMMARIZE_THRESHOLD = 0.8
DEFAULT_SUMMARY_PROMPT = "Summarize the following conversation succinctly, preserving key facts:\n\n"


class ClientSessionSummarizing(ClientSession):
    def __init__(
        self,
        read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
        write_stream: MemoryObjectSendStream[SessionMessage],
        read_timeout_seconds: timedelta | None = None,
        sampling_callback: Any | None = None,
        elicitation_callback: Any | None = None,
        list_roots_callback: Any | None = None,
        logging_callback: Any | None = None,
        message_handler: Any | None = None,
        client_info: Any | None = None,
        max_tokens: int | None = None,
        summarize_threshold: float | None = None,
        summary_prompt: str | None = None,
    ) -> None:
        super().__init__(
            read_stream=read_stream,
            write_stream=write_stream,
            read_timeout_seconds=read_timeout_seconds,
            sampling_callback=sampling_callback,
            elicitation_callback=elicitation_callback,
            list_roots_callback=list_roots_callback,
            logging_callback=logging_callback,
            message_handler=message_handler,
            client_info=client_info,
        )
        self.history: list[SamplingMessage] = []
        self.max_tokens = max_tokens or DEFAULT_MAX_TOKENS
        self.summarize_threshold = summarize_threshold or DEFAULT_SUMMARIZE_THRESHOLD
        self.summary_prompt = summary_prompt or DEFAULT_SUMMARY_PROMPT
        # Override the sampling callback to include our summarization logic
        self._sampling_callback = self._summarizing_sampling_callback

    async def _summarizing_sampling_callback(
        self,
        context: RequestContext["ClientSession", Any],
        params: CreateMessageRequestParams,
    ) -> CreateMessageResult:
        """Custom sampling callback that includes summarization logic."""
        # Add messages to history
        self.history.extend(params.messages)

        # Check if we need to summarize
        if self.token_count() > self.max_tokens * self.summarize_threshold:
            await self.summarize_context()

        # For now, return a simple response
        # In a real implementation, you might want to call an LLM service here
        return CreateMessageResult(
            role="assistant",
            content=TextContent(type="text", text="Message processed with summarization"),
            model="summarizing-model",
            stopReason="endTurn",
        )

    def token_count(self) -> int:
        """Calculate token count for all messages in history."""
        tokenizer = tiktoken.get_encoding("cl100k_base")
        total_tokens = 0

        for message in self.history:
            if isinstance(message.content, TextContent):
                total_tokens += len(tokenizer.encode(message.content.text))
            elif isinstance(message.content, str):
                total_tokens += len(tokenizer.encode(message.content))

        return total_tokens

    async def summarize_context(self) -> None:
        """Summarize the conversation history and replace it with a summary."""
        if not self.history:
            return

        # Create a summary prompt from all messages
        summary_text = self.summary_prompt
        for message in self.history:
            if isinstance(message.content, TextContent):
                summary_text += f"{message.role}: {message.content.text}\n"
            elif isinstance(message.content, str):
                summary_text += f"{message.role}: {message.content}\n"

        # Create a summary message
        summary_message = SamplingMessage(role="assistant", content=TextContent(type="text", text=summary_text))

        # Replace history with summary
        self.history = [summary_message]
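
The callback above deliberately returns a placeholder CreateMessageResult, and its comment notes that a real client would call an LLM instead. One possible shape for that is sketched below under the assumption of an application-provided complete() helper; the helper and the model name are invented for illustration, while the bookkeeping mirrors the class as written.

# Illustrative sketch only -- complete() and "app-provided-model" are hypothetical.
from typing import Any

from mcp.client.client_session_summarizing import ClientSessionSummarizing
from mcp.client.session import ClientSession
from mcp.shared.context import RequestContext
from mcp.types import CreateMessageRequestParams, CreateMessageResult, TextContent


async def complete(prompt: str) -> str:
    """Hypothetical stand-in for the application's own LLM client call."""
    raise NotImplementedError


class LLMBackedSummarizingSession(ClientSessionSummarizing):
    async def _summarizing_sampling_callback(
        self,
        context: RequestContext[ClientSession, Any],
        params: CreateMessageRequestParams,
    ) -> CreateMessageResult:
        # Same bookkeeping as the base class: record messages, summarize if needed.
        self.history.extend(params.messages)
        if self.token_count() > self.max_tokens * self.summarize_threshold:
            await self.summarize_context()

        # Forward the (possibly summarized) history to the real model.
        prompt = "\n".join(
            f"{m.role}: {m.content.text}" for m in self.history if isinstance(m.content, TextContent)
        )
        reply = await complete(prompt)
        return CreateMessageResult(
            role="assistant",
            content=TextContent(type="text", text=reply),
            model="app-provided-model",
            stopReason="endTurn",
        )
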
196 changes: 196 additions & 0 deletions tests/client/test_client_session_summarizing.py
@@ -0,0 +1,196 @@
from typing import Any

import anyio
import pytest

from mcp.client.client_session_summarizing import (
    DEFAULT_MAX_TOKENS,
    DEFAULT_SUMMARIZE_THRESHOLD,
    DEFAULT_SUMMARY_PROMPT,
    ClientSessionSummarizing,
)
from mcp.shared.context import RequestContext
from mcp.types import (
    CreateMessageRequestParams,
    CreateMessageResult,
    SamplingMessage,
    TextContent,
)


@pytest.mark.asyncio
async def test_summarizing_session():
    send_stream, receive_stream = anyio.create_memory_object_stream(10)
    try:
        session = ClientSessionSummarizing(
            read_stream=receive_stream,
            write_stream=send_stream,
        )

        # Create real messages instead of simple strings
        messages = [SamplingMessage(role="user", content=TextContent(type="text", text="Hello")) for _ in range(3500)]
        session.history = messages  # Simulate approaching token limit

        assert session.token_count() > session.max_tokens * session.summarize_threshold

        # Test that summarization works
        await session.summarize_context()

        # After summarization, history should contain only one message
        assert len(session.history) == 1
        assert isinstance(session.history[0], SamplingMessage)
        assert session.history[0].role == "assistant"

    finally:
        await send_stream.aclose()
        await receive_stream.aclose()


@pytest.mark.asyncio
async def test_sampling_callback():
    """Test sampling callback with ClientSessionSummarizing"""
    send_stream, receive_stream = anyio.create_memory_object_stream(10)
    try:
        session = ClientSessionSummarizing(
            read_stream=receive_stream,
            write_stream=send_stream,
        )

        # Create request parameters
        params = CreateMessageRequestParams(
            messages=[SamplingMessage(role="user", content=TextContent(type="text", text="Hello world"))],
            maxTokens=100,
        )

        # Create simple context for testing
        context: Any = RequestContext(session=session, request_id=1, meta=None, lifespan_context=None)

        # Call sampling callback
        result = await session._summarizing_sampling_callback(context, params)

        # Verify the result is correct
        assert isinstance(result, CreateMessageResult)
        assert result.role == "assistant"
        assert isinstance(result.content, TextContent)
        assert "Message processed with summarization" in result.content.text

        # Verify message was added to history
        assert len(session.history) == 1
        assert session.history[0].role == "user"
        assert isinstance(session.history[0].content, TextContent)
        assert session.history[0].content.text == "Hello world"

    finally:
        await send_stream.aclose()
        await receive_stream.aclose()


@pytest.mark.asyncio
async def test_custom_summary_prompt():
    """Test that user can define custom prompt"""
    send_stream, receive_stream = anyio.create_memory_object_stream(10)
    try:
        custom_prompt = "Custom summary prompt:\n\n"
        session = ClientSessionSummarizing(
            read_stream=receive_stream,
            write_stream=send_stream,
            summary_prompt=custom_prompt,
        )

        # Verify user can define custom prompt
        assert session.summary_prompt == custom_prompt
        assert session.summary_prompt != DEFAULT_SUMMARY_PROMPT

        # Test that summarization uses custom prompt
        session.history = [SamplingMessage(role="user", content=TextContent(type="text", text="Test message"))]

        await session.summarize_context()

        # Verify summary contains custom prompt
        assert len(session.history) == 1
        summary_content = session.history[0].content
        assert isinstance(summary_content, TextContent)
        assert custom_prompt in summary_content.text

    finally:
        await send_stream.aclose()
        await receive_stream.aclose()


@pytest.mark.asyncio
async def test_default_summary_prompt():
    """Test that user gets default prompt if not specified"""
    send_stream, receive_stream = anyio.create_memory_object_stream(10)
    try:
        session = ClientSessionSummarizing(
            read_stream=receive_stream,
            write_stream=send_stream,
        )

        # Verify user gets default prompt
        assert session.summary_prompt == DEFAULT_SUMMARY_PROMPT

    finally:
        await send_stream.aclose()
        await receive_stream.aclose()


@pytest.mark.asyncio
async def test_custom_max_tokens():
    """Test that user can define custom max tokens"""
    send_stream, receive_stream = anyio.create_memory_object_stream(10)
    try:
        custom_max_tokens = 2000
        session = ClientSessionSummarizing(
            read_stream=receive_stream,
            write_stream=send_stream,
            max_tokens=custom_max_tokens,
        )

        # Verify user can define custom max tokens
        assert session.max_tokens == custom_max_tokens
        assert session.max_tokens != DEFAULT_MAX_TOKENS

    finally:
        await send_stream.aclose()
        await receive_stream.aclose()


@pytest.mark.asyncio
async def test_custom_summarize_threshold():
    """Test that user can define custom summarize threshold"""
    send_stream, receive_stream = anyio.create_memory_object_stream(10)
    try:
        custom_threshold = 0.5
        session = ClientSessionSummarizing(
            read_stream=receive_stream,
            write_stream=send_stream,
            summarize_threshold=custom_threshold,
        )

        # Verify user can define custom threshold
        assert session.summarize_threshold == custom_threshold
        assert session.summarize_threshold != DEFAULT_SUMMARIZE_THRESHOLD

    finally:
        await send_stream.aclose()
        await receive_stream.aclose()


@pytest.mark.asyncio
async def test_default_parameters():
    """Test that user gets default parameters if not specified"""
    send_stream, receive_stream = anyio.create_memory_object_stream(10)
    try:
        session = ClientSessionSummarizing(
            read_stream=receive_stream,
            write_stream=send_stream,
        )

        # Verify user gets default parameters
        assert session.max_tokens == DEFAULT_MAX_TOKENS
        assert session.summarize_threshold == DEFAULT_SUMMARIZE_THRESHOLD
        assert session.summary_prompt == DEFAULT_SUMMARY_PROMPT

    finally:
        await send_stream.aclose()
        await receive_stream.aclose()