feat: add HistoryProcessors wrapper #2124


Open · wants to merge 3 commits into main
36 changes: 36 additions & 0 deletions docs/message-history.md
@@ -522,6 +522,42 @@ agent = Agent('openai:gpt-4o', history_processors=[filter_responses, summarize_old_messages])
In this case, the `filter_responses` processor will be applied first, and the
`summarize_old_messages` processor will be applied second.

### Modifying Message History

By default, history processors only modify the messages sent to the model without changing the original conversation history. However, you can use `HistoryProcessors` with `replace_history=True` to actually modify the original message history stored in the agent.

This is useful for scenarios like permanently compressing long conversations, implementing sliding window memory management (sketched after the example below), or removing sensitive information from the conversation history.

```python {title="modify_message_history_with_replace_history.py"}
from pydantic_ai import Agent, HistoryProcessors
from pydantic_ai.messages import ModelMessage

# Use a cheaper model to summarize old messages.
summarize_agent = Agent(
    'openai:gpt-4o-mini',
    instructions="""
Summarize this conversation, omitting small talk and unrelated topics.
Focus on the technical discussion and next steps.
""",
)


async def summarize_old_messages(messages: list[ModelMessage]) -> list[ModelMessage]:
    # Summarize the oldest 10 messages
    if len(messages) > 10:
        oldest_messages = messages[:10]
        summary = await summarize_agent.run(message_history=oldest_messages)
        # Return the summary followed by the most recent message
        return summary.new_messages() + messages[-1:]

    return messages


agent = Agent('openai:gpt-4o', history_processors=HistoryProcessors(funcs=[summarize_old_messages], replace_history=True))
```
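
The sliding-window case mentioned above follows the same pattern. Here is a minimal sketch, assuming a hypothetical `keep_recent_messages` processor (it is illustrative, not part of the library):

```python {title="sliding_window_with_replace_history.py"}
from pydantic_ai import Agent, HistoryProcessors
from pydantic_ai.messages import ModelMessage


def keep_recent_messages(messages: list[ModelMessage]) -> list[ModelMessage]:
    # Keep only the 10 most recent messages; with replace_history=True,
    # the dropped messages are also removed from the stored conversation.
    return messages[-10:]


agent = Agent(
    'openai:gpt-4o',
    history_processors=HistoryProcessors(funcs=[keep_recent_messages], replace_history=True),
)
```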

**Note:** When `replace_history=False` (the default), the behavior is the same as passing a list of processors directly: the original conversation history remains unchanged.
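
For instance, the following two declarations behave identically (a minimal sketch; `filter_responses` here is an illustrative stand-in for the processor defined earlier in this document):

```python {title="replace_history_false_equivalence.py"}
from pydantic_ai import Agent, HistoryProcessors
from pydantic_ai.messages import ModelMessage, ModelRequest


def filter_responses(messages: list[ModelMessage]) -> list[ModelMessage]:
    # Illustrative stand-in: keep only request messages.
    return [msg for msg in messages if isinstance(msg, ModelRequest)]


# A plain list is wrapped internally as HistoryProcessors(funcs=[...], replace_history=False),
# so neither agent mutates the stored conversation history.
agent_a = Agent('openai:gpt-4o', history_processors=[filter_responses])
agent_b = Agent('openai:gpt-4o', history_processors=HistoryProcessors(funcs=[filter_responses]))
```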

## Examples

For a more complete example of using messages in conversations, see the [chat app](examples/chat-app.md) example.
11 changes: 10 additions & 1 deletion pydantic_ai_slim/pydantic_ai/__init__.py
@@ -1,6 +1,14 @@
from importlib.metadata import version as _metadata_version

from .agent import Agent, CallToolsNode, EndStrategy, ModelRequestNode, UserPromptNode, capture_run_messages
from .agent import (
    Agent,
    CallToolsNode,
    EndStrategy,
    HistoryProcessors,
    ModelRequestNode,
    UserPromptNode,
    capture_run_messages,
)
from .exceptions import (
    AgentRunError,
    FallbackExceptionGroup,
@@ -19,6 +27,7 @@
    '__version__',
    # agent
    'Agent',
    'HistoryProcessors',
    'EndStrategy',
    'CallToolsNode',
    'ModelRequestNode',
30 changes: 20 additions & 10 deletions pydantic_ai_slim/pydantic_ai/_agent_graph.py
@@ -35,6 +35,7 @@
    'build_run_context',
    'capture_run_messages',
    'HistoryProcessor',
    'HistoryProcessors',
)


@@ -68,6 +69,16 @@
"""


@dataclasses.dataclass
class HistoryProcessors(Generic[DepsT]):
    """A wrapper for a list of history processors."""

    funcs: list[HistoryProcessor[DepsT]]
    """A list of functions to process the message history."""
    replace_history: bool = False
    """Whether to replace the message history with the processed history."""


@dataclasses.dataclass
class GraphAgentState:
"""State kept across the execution of the agent graph."""
@@ -106,7 +117,7 @@ class GraphAgentDeps(Generic[DepsT, OutputDataT]):
    output_schema: _output.OutputSchema[OutputDataT]
    output_validators: list[_output.OutputValidator[DepsT, OutputDataT]]

    history_processors: Sequence[HistoryProcessor[DepsT]]
    history_processors: HistoryProcessors[DepsT]

    function_tools: dict[str, Tool[DepsT]] = dataclasses.field(repr=False)
    mcp_servers: Sequence[MCPServer] = dataclasses.field(repr=False)
@@ -358,9 +369,7 @@ async def _stream(

        model_settings, model_request_parameters = await self._prepare_request(ctx)
        model_request_parameters = ctx.deps.model.customize_request_parameters(model_request_parameters)
        message_history = await _process_message_history(
            ctx.state.message_history, ctx.deps.history_processors, build_run_context(ctx)
        )
        message_history = await _process_message_history(ctx.state, ctx.deps.history_processors, build_run_context(ctx))
        async with ctx.deps.model.request_stream(
            message_history, model_settings, model_request_parameters
        ) as streamed_response:
@@ -384,9 +393,7 @@ async def _make_request(

        model_settings, model_request_parameters = await self._prepare_request(ctx)
        model_request_parameters = ctx.deps.model.customize_request_parameters(model_request_parameters)
        message_history = await _process_message_history(
            ctx.state.message_history, ctx.deps.history_processors, build_run_context(ctx)
        )
        message_history = await _process_message_history(ctx.state, ctx.deps.history_processors, build_run_context(ctx))
        model_response = await ctx.deps.model.request(message_history, model_settings, model_request_parameters)
        ctx.state.usage.incr(_usage.Usage())
@@ -955,12 +962,13 @@ def build_agent_graph(


async def _process_message_history(
    messages: list[_messages.ModelMessage],
    processors: Sequence[HistoryProcessor[DepsT]],
    state: GraphAgentState,
    processors: HistoryProcessors[DepsT],
    run_context: RunContext[DepsT],
) -> list[_messages.ModelMessage]:
    """Process message history through a sequence of processors."""
    messages = state.message_history
    for processor in processors.funcs:
        takes_ctx = is_takes_ctx(processor)

        if is_async_callable(processor):
@@ -976,4 +984,6 @@
        else:
            sync_processor = cast(_HistoryProcessorSync, processor)
            messages = await run_in_executor(sync_processor, messages)
    if processors.replace_history:
        state.message_history = messages
    return messages
17 changes: 12 additions & 5 deletions pydantic_ai_slim/pydantic_ai/agent.py
@@ -30,7 +30,7 @@
    result,
    usage as _usage,
)
from ._agent_graph import HistoryProcessor
from ._agent_graph import HistoryProcessor, HistoryProcessors
from .models.instrumented import InstrumentationSettings, InstrumentedModel, instrument_model
from .output import OutputDataT, OutputSpec
from .result import FinalResult, StreamedRunResult
@@ -78,6 +78,7 @@
    'ModelRequestNode',
    'UserPromptNode',
    'InstrumentationSettings',
    'HistoryProcessors',
)


@@ -181,7 +182,7 @@ def __init__(
        defer_model_check: bool = False,
        end_strategy: EndStrategy = 'early',
        instrument: InstrumentationSettings | bool | None = None,
        history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None,
        history_processors: Sequence[HistoryProcessor[AgentDepsT]] | HistoryProcessors[AgentDepsT] | None = None,
    ) -> None: ...

    @overload
@@ -211,7 +212,7 @@ def __init__(
        defer_model_check: bool = False,
        end_strategy: EndStrategy = 'early',
        instrument: InstrumentationSettings | bool | None = None,
        history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None,
        history_processors: Sequence[HistoryProcessor[AgentDepsT]] | HistoryProcessors[AgentDepsT] | None = None,
    ) -> None: ...

    def __init__(
@@ -236,7 +237,7 @@ def __init__(
        defer_model_check: bool = False,
        end_strategy: EndStrategy = 'early',
        instrument: InstrumentationSettings | bool | None = None,
        history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None,
        history_processors: Sequence[HistoryProcessor[AgentDepsT]] | HistoryProcessors[AgentDepsT] | None = None,
        **_deprecated_kwargs: Any,
    ):
        """Create an agent.
@@ -359,7 +360,13 @@ def __init__(
        self._max_result_retries = output_retries if output_retries is not None else retries
        self._mcp_servers = mcp_servers
        self._prepare_tools = prepare_tools
        self.history_processors = history_processors or []
        history_processors = history_processors or []
        self.history_processors = cast(
            HistoryProcessors[AgentDepsT],
            HistoryProcessors(funcs=list(history_processors))
            if not isinstance(history_processors, HistoryProcessors)
            else history_processors,
        )
        for tool in tools:
            if isinstance(tool, Tool):
                self._register_tool(tool)
110 changes: 109 additions & 1 deletion tests/test_history_processor.py
@@ -4,7 +4,7 @@
import pytest
from inline_snapshot import snapshot

from pydantic_ai import Agent
from pydantic_ai import Agent, HistoryProcessors
from pydantic_ai.messages import ModelMessage, ModelRequest, ModelRequestPart, ModelResponse, TextPart, UserPromptPart
from pydantic_ai.models.function import AgentInfo, FunctionModel
from pydantic_ai.tools import RunContext
@@ -301,3 +301,111 @@ class Deps:
    user_part = msg.parts[0]
    assert isinstance(user_part, UserPromptPart)
    assert cast(str, user_part.content).startswith('TEST: ')


async def test_history_processors_replace_history_true(function_model: FunctionModel):
    """Test HistoryProcessors with replace_history=True modifies original history."""

    def keep_only_requests(messages: list[ModelMessage]) -> list[ModelMessage]:
        return [msg for msg in messages if isinstance(msg, ModelRequest)]

    processors = HistoryProcessors(funcs=[keep_only_requests], replace_history=True)  # type: ignore
    agent = Agent(function_model, history_processors=processors)  # type: ignore

    original_history = [
        ModelRequest(parts=[UserPromptPart(content='Question 1')]),
        ModelResponse(parts=[TextPart(content='Answer 1')]),
        ModelRequest(parts=[UserPromptPart(content='Question 2')]),
        ModelResponse(parts=[TextPart(content='Answer 2')]),
    ]

    result = await agent.run('Question 3', message_history=original_history.copy())

    # Verify the history was modified - responses should be removed
    all_messages = result.all_messages()
    requests = [msg for msg in all_messages if isinstance(msg, ModelRequest)]
    responses = [msg for msg in all_messages if isinstance(msg, ModelResponse)]

    # Should have 3 requests (2 original + 1 new) and 1 response (only the new one)
    assert len(requests) == 3
    assert len(responses) == 1


async def test_history_processors_multiple_with_replace_history(function_model: FunctionModel):
    """Test multiple processors with replace_history=True."""

    def remove_responses(messages: list[ModelMessage]) -> list[ModelMessage]:
        return [msg for msg in messages if isinstance(msg, ModelRequest)]

    def keep_recent(messages: list[ModelMessage]) -> list[ModelMessage]:
        return messages[-2:] if len(messages) > 2 else messages

    processors = HistoryProcessors(  # type: ignore
        funcs=[remove_responses, keep_recent], replace_history=True
    )
    agent = Agent(function_model, history_processors=processors)  # type: ignore

    # Create history with 4 requests and 4 responses
    original_history: list[ModelMessage] = []
    for i in range(4):
        original_history.append(ModelRequest(parts=[UserPromptPart(content=f'Question {i + 1}')]))
        original_history.append(ModelResponse(parts=[TextPart(content=f'Answer {i + 1}')]))

    result = await agent.run('Final question', message_history=original_history.copy())

    # After processing: remove responses -> keep recent 2 -> add new exchange
    all_messages = result.all_messages()
    requests = [msg for msg in all_messages if isinstance(msg, ModelRequest)]
    responses = [msg for msg in all_messages if isinstance(msg, ModelResponse)]

    # Should have 2 requests (1 kept from history + 1 new) and 1 response (new only);
    # all original responses were removed
    assert len(requests) == 2
    assert len(responses) == 1


async def test_history_processors_streaming_with_replace_history(function_model: FunctionModel):
    """Test replace_history=True works with streaming runs."""

    def summarize_history(messages: list[ModelMessage]) -> list[ModelMessage]:
        # Simple summarization - keep only the last message
        return messages[-1:] if messages else []

    processors = HistoryProcessors(funcs=[summarize_history], replace_history=True)  # type: ignore
    agent = Agent(function_model, history_processors=processors)  # type: ignore

    original_history = [
        ModelRequest(parts=[UserPromptPart(content='Question 1')]),
        ModelResponse(parts=[TextPart(content='Answer 1')]),
        ModelRequest(parts=[UserPromptPart(content='Question 2')]),
        ModelResponse(parts=[TextPart(content='Answer 2')]),
    ]

    async with agent.run_stream('Question 3', message_history=original_history.copy()) as result:
        async for _ in result.stream_text():
            pass

    # Verify history was modified during streaming
    all_messages = result.all_messages()
    # Should only have: new request + new response = 2 total
    assert len(all_messages) == 2


async def test_history_processors_replace_history_false_default(function_model: FunctionModel):
    """Test HistoryProcessors with replace_history=False (default) preserves original history."""

    def keep_only_requests(messages: list[ModelMessage]) -> list[ModelMessage]:
        return [msg for msg in messages if isinstance(msg, ModelRequest)]

    processors = HistoryProcessors(funcs=[keep_only_requests])  # replace_history=False by default  # type: ignore
    agent = Agent(function_model, history_processors=processors)  # type: ignore

    original_history = [
        ModelRequest(parts=[UserPromptPart(content='Question 1')]),
        ModelResponse(parts=[TextPart(content='Answer 1')]),
    ]

    result = await agent.run('Question 2', message_history=original_history.copy())

    # Verify original history is preserved
    all_messages = result.all_messages()
    assert len(all_messages) == 4  # 2 original + 1 new request + 1 new response