diff --git a/docs/message-history.md b/docs/message-history.md
index 179c4c291..11a1f8291 100644
--- a/docs/message-history.md
+++ b/docs/message-history.md
@@ -522,6 +522,42 @@ agent = Agent('openai:gpt-4o', history_processors=[filter_responses, summarize_o
 
 In this case, the `filter_responses` processor will be applied first, and the `summarize_old_messages` processor will be applied second.
 
+### Modifying Message History
+
+By default, history processors only modify the messages sent to the model, leaving the original conversation history unchanged. However, you can use `HistoryProcessors` with `replace_history=True` to modify the original message history of the run itself.
+
+This is useful for scenarios like permanently compressing long conversations, implementing sliding window memory management, or removing sensitive information from the conversation history.
+
+```python {title="modify_message_history_with_replace_history.py"}
+from pydantic_ai import Agent, HistoryProcessors
+from pydantic_ai.messages import ModelMessage
+
+# Use a cheaper model to summarize old messages.
+summarize_agent = Agent(
+    'openai:gpt-4o-mini',
+    instructions="""
+Summarize this conversation, omitting small talk and unrelated topics.
+Focus on the technical discussion and next steps.
+""",
+)
+
+
+async def summarize_old_messages(messages: list[ModelMessage]) -> list[ModelMessage]:
+    # Summarize the oldest 10 messages
+    if len(messages) > 10:
+        oldest_messages = messages[:10]
+        summary = await summarize_agent.run(message_history=oldest_messages)
+        # Return the summary followed by the most recent message
+        return summary.new_messages() + messages[-1:]
+
+    return messages
+
+
+agent = Agent('openai:gpt-4o', history_processors=HistoryProcessors(funcs=[summarize_old_messages], replace_history=True))
+```
+
+**Note:** When `replace_history=False` (the default), the behavior is the same as passing a list of processors directly: the original conversation history remains unchanged.
+
 ## Examples
 
 For a more complete example of using messages in conversations, see the [chat app](examples/chat-app.md) example.
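To make the new flag concrete before the implementation changes below, here is a minimal sketch of the sliding-window use case the docs mention, assuming the `HistoryProcessors` wrapper introduced in this diff; the `keep_recent_window` helper and the window size of 6 are illustrative, not part of this PR.

```python
from pydantic_ai import Agent, HistoryProcessors
from pydantic_ai.messages import ModelMessage


def keep_recent_window(messages: list[ModelMessage]) -> list[ModelMessage]:
    # Sliding window: keep only the 6 most recent messages.
    return messages[-6:]


# With replace_history=True the trimmed list is written back to the run's
# stored history, so result.all_messages() reflects the compression too.
agent = Agent(
    'openai:gpt-4o',
    history_processors=HistoryProcessors(funcs=[keep_recent_window], replace_history=True),
)
```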
diff --git a/pydantic_ai_slim/pydantic_ai/__init__.py b/pydantic_ai_slim/pydantic_ai/__init__.py
index aa50774d0..294558864 100644
--- a/pydantic_ai_slim/pydantic_ai/__init__.py
+++ b/pydantic_ai_slim/pydantic_ai/__init__.py
@@ -1,6 +1,14 @@
 from importlib.metadata import version as _metadata_version
 
-from .agent import Agent, CallToolsNode, EndStrategy, ModelRequestNode, UserPromptNode, capture_run_messages
+from .agent import (
+    Agent,
+    CallToolsNode,
+    EndStrategy,
+    HistoryProcessors,
+    ModelRequestNode,
+    UserPromptNode,
+    capture_run_messages,
+)
 from .exceptions import (
     AgentRunError,
     FallbackExceptionGroup,
@@ -19,6 +27,7 @@
     '__version__',
     # agent
     'Agent',
+    'HistoryProcessors',
     'EndStrategy',
     'CallToolsNode',
     'ModelRequestNode',
diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py
index 4515d18bc..5b6f84030 100644
--- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py
+++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py
@@ -35,6 +35,7 @@
     'build_run_context',
     'capture_run_messages',
     'HistoryProcessor',
+    'HistoryProcessors',
 )
 
 
@@ -68,6 +69,16 @@
 """
 
 
+@dataclasses.dataclass
+class HistoryProcessors(Generic[DepsT]):
+    """A wrapper for a list of history processors."""
+
+    funcs: list[HistoryProcessor[DepsT]]
+    """A list of functions to process the message history."""
+    replace_history: bool = False
+    """Whether to replace the message history with the processed history."""
+
+
 @dataclasses.dataclass
 class GraphAgentState:
     """State kept across the execution of the agent graph."""
@@ -106,7 +117,7 @@ class GraphAgentDeps(Generic[DepsT, OutputDataT]):
     output_schema: _output.OutputSchema[OutputDataT]
     output_validators: list[_output.OutputValidator[DepsT, OutputDataT]]
 
-    history_processors: Sequence[HistoryProcessor[DepsT]]
+    history_processors: HistoryProcessors[DepsT]
 
     function_tools: dict[str, Tool[DepsT]] = dataclasses.field(repr=False)
     mcp_servers: Sequence[MCPServer] = dataclasses.field(repr=False)
@@ -358,9 +369,7 @@ async def _stream(
         model_settings, model_request_parameters = await self._prepare_request(ctx)
         model_request_parameters = ctx.deps.model.customize_request_parameters(model_request_parameters)
 
-        message_history = await _process_message_history(
-            ctx.state.message_history, ctx.deps.history_processors, build_run_context(ctx)
-        )
+        message_history = await _process_message_history(ctx.state, ctx.deps.history_processors, build_run_context(ctx))
         async with ctx.deps.model.request_stream(
             message_history, model_settings, model_request_parameters
         ) as streamed_response:
@@ -384,9 +393,7 @@ async def _make_request(
         model_settings, model_request_parameters = await self._prepare_request(ctx)
         model_request_parameters = ctx.deps.model.customize_request_parameters(model_request_parameters)
 
-        message_history = await _process_message_history(
-            ctx.state.message_history, ctx.deps.history_processors, build_run_context(ctx)
-        )
+        message_history = await _process_message_history(ctx.state, ctx.deps.history_processors, build_run_context(ctx))
         model_response = await ctx.deps.model.request(message_history, model_settings, model_request_parameters)
         ctx.state.usage.incr(_usage.Usage())
 
@@ -955,12 +962,13 @@ def build_agent_graph(
 
 
 async def _process_message_history(
-    messages: list[_messages.ModelMessage],
-    processors: Sequence[HistoryProcessor[DepsT]],
+    state: GraphAgentState,
+    processors: HistoryProcessors[DepsT],
     run_context: RunContext[DepsT],
 ) -> list[_messages.ModelMessage]:
     """Process message history through a sequence of processors."""
-    for processor in processors:
+    messages = state.message_history
+    for processor in processors.funcs:
         takes_ctx = is_takes_ctx(processor)
 
         if is_async_callable(processor):
@@ -976,4 +984,6 @@ async def _process_message_history(
         else:
             sync_processor = cast(_HistoryProcessorSync, processor)
             messages = await run_in_executor(sync_processor, messages)
+    if processors.replace_history:
+        state.message_history = messages
     return messages
diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py
index 53e8416c0..b3eeabf0d 100644
--- a/pydantic_ai_slim/pydantic_ai/agent.py
+++ b/pydantic_ai_slim/pydantic_ai/agent.py
@@ -30,7 +30,7 @@
     result,
     usage as _usage,
 )
-from ._agent_graph import HistoryProcessor
+from ._agent_graph import HistoryProcessor, HistoryProcessors
 from .models.instrumented import InstrumentationSettings, InstrumentedModel, instrument_model
 from .output import OutputDataT, OutputSpec
 from .result import FinalResult, StreamedRunResult
@@ -78,6 +78,7 @@
     'ModelRequestNode',
     'UserPromptNode',
     'InstrumentationSettings',
+    'HistoryProcessors',
 )
 
 
@@ -181,7 +182,7 @@ def __init__(
         defer_model_check: bool = False,
         end_strategy: EndStrategy = 'early',
         instrument: InstrumentationSettings | bool | None = None,
-        history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None,
+        history_processors: Sequence[HistoryProcessor[AgentDepsT]] | HistoryProcessors[AgentDepsT] | None = None,
     ) -> None: ...
 
     @overload
@@ -211,7 +212,7 @@ def __init__(
         defer_model_check: bool = False,
         end_strategy: EndStrategy = 'early',
         instrument: InstrumentationSettings | bool | None = None,
-        history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None,
+        history_processors: Sequence[HistoryProcessor[AgentDepsT]] | HistoryProcessors[AgentDepsT] | None = None,
     ) -> None: ...
 
     def __init__(
@@ -236,7 +237,7 @@ def __init__(
        defer_model_check: bool = False,
        end_strategy: EndStrategy = 'early',
        instrument: InstrumentationSettings | bool | None = None,
-       history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None,
+       history_processors: Sequence[HistoryProcessor[AgentDepsT]] | HistoryProcessors[AgentDepsT] | None = None,
        **_deprecated_kwargs: Any,
     ):
         """Create an agent.
@@ -359,7 +360,13 @@ def __init__(
         self._max_result_retries = output_retries if output_retries is not None else retries
         self._mcp_servers = mcp_servers
         self._prepare_tools = prepare_tools
-        self.history_processors = history_processors or []
+        history_processors = history_processors or []
+        self.history_processors = cast(
+            HistoryProcessors[AgentDepsT],
+            HistoryProcessors(funcs=list(history_processors))
+            if not isinstance(history_processors, HistoryProcessors)
+            else history_processors,
+        )
         for tool in tools:
             if isinstance(tool, Tool):
                 self._register_tool(tool)
diff --git a/tests/test_history_processor.py b/tests/test_history_processor.py
index 1a1e2ffa3..d7db9ac61 100644
--- a/tests/test_history_processor.py
+++ b/tests/test_history_processor.py
@@ -4,7 +4,7 @@
 import pytest
 from inline_snapshot import snapshot
 
-from pydantic_ai import Agent
+from pydantic_ai import Agent, HistoryProcessors
 from pydantic_ai.messages import ModelMessage, ModelRequest, ModelRequestPart, ModelResponse, TextPart, UserPromptPart
 from pydantic_ai.models.function import AgentInfo, FunctionModel
 from pydantic_ai.tools import RunContext
@@ -301,3 +301,111 @@ class Deps:
     user_part = msg.parts[0]
     assert isinstance(user_part, UserPromptPart)
     assert cast(str, user_part.content).startswith('TEST: ')
+
+
+async def test_history_processors_replace_history_true(function_model: FunctionModel):
+    """Test HistoryProcessors with replace_history=True modifies original history."""
+
+    def keep_only_requests(messages: list[ModelMessage]) -> list[ModelMessage]:
+        return [msg for msg in messages if isinstance(msg, ModelRequest)]
+
+    processors = HistoryProcessors(funcs=[keep_only_requests], replace_history=True)  # type: ignore
+    agent = Agent(function_model, history_processors=processors)  # type: ignore
+
+    original_history = [
+        ModelRequest(parts=[UserPromptPart(content='Question 1')]),
+        ModelResponse(parts=[TextPart(content='Answer 1')]),
+        ModelRequest(parts=[UserPromptPart(content='Question 2')]),
+        ModelResponse(parts=[TextPart(content='Answer 2')]),
+    ]
+
+    result = await agent.run('Question 3', message_history=original_history.copy())
+
+    # Verify the history was modified - responses should be removed
+    all_messages = result.all_messages()
+    requests = [msg for msg in all_messages if isinstance(msg, ModelRequest)]
+    responses = [msg for msg in all_messages if isinstance(msg, ModelResponse)]
+
+    # Should have 3 requests (2 original + 1 new) and 1 response (only the new one)
+    assert len(requests) == 3
+    assert len(responses) == 1
+
+
+async def test_history_processors_multiple_with_replace_history(function_model: FunctionModel):
+    """Test multiple processors with replace_history=True."""
+
+    def remove_responses(messages: list[ModelMessage]) -> list[ModelMessage]:
+        return [msg for msg in messages if isinstance(msg, ModelRequest)]
+
+    def keep_recent(messages: list[ModelMessage]) -> list[ModelMessage]:
+        return messages[-2:] if len(messages) > 2 else messages
+
+    processors = HistoryProcessors(  # type: ignore
+        funcs=[remove_responses, keep_recent], replace_history=True
+    )
+    agent = Agent(function_model, history_processors=processors)  # type: ignore
+
+    # Create history with 4 requests and 4 responses
+    original_history: list[ModelMessage] = []
+    for i in range(4):
+        original_history.append(ModelRequest(parts=[UserPromptPart(content=f'Question {i + 1}')]))
+        original_history.append(ModelResponse(parts=[TextPart(content=f'Answer {i + 1}')]))
+
+    result = await agent.run('Final question', message_history=original_history.copy())
+
+    # After processing: remove responses -> keep recent 2 -> add new exchange
+    all_messages = result.all_messages()
+    requests = [msg for msg in all_messages if isinstance(msg, ModelRequest)]
+    responses = [msg for msg in all_messages if isinstance(msg, ModelResponse)]
+
+    # Should have 2 requests (1 kept from history + 1 new) and 1 response (the new one only)
+    assert len(requests) == 2
+    assert len(responses) == 1
+
+
+async def test_history_processors_streaming_with_replace_history(function_model: FunctionModel):
+    """Test replace_history=True works with streaming runs."""
+
+    def summarize_history(messages: list[ModelMessage]) -> list[ModelMessage]:
+        # Simple summarization - keep only the last message
+        return messages[-1:] if messages else []
+
+    processors = HistoryProcessors(funcs=[summarize_history], replace_history=True)  # type: ignore
+    agent = Agent(function_model, history_processors=processors)  # type: ignore
+
+    original_history = [
+        ModelRequest(parts=[UserPromptPart(content='Question 1')]),
+        ModelResponse(parts=[TextPart(content='Answer 1')]),
+        ModelRequest(parts=[UserPromptPart(content='Question 2')]),
+        ModelResponse(parts=[TextPart(content='Answer 2')]),
+    ]
+
+    async with agent.run_stream('Question 3', message_history=original_history.copy()) as result:
+        async for _ in result.stream_text():
+            pass
+
+    # Verify history was modified during streaming
+    all_messages = result.all_messages()
+    # Should only have: new request + new response = 2 total
+    assert len(all_messages) == 2
+
+
+async def test_history_processors_replace_history_false_default(function_model: FunctionModel):
+    """Test HistoryProcessors with replace_history=False (default) preserves original history."""
+
+    def keep_only_requests(messages: list[ModelMessage]) -> list[ModelMessage]:
+        return [msg for msg in messages if isinstance(msg, ModelRequest)]
+
+    processors = HistoryProcessors(funcs=[keep_only_requests])  # replace_history=False by default  # type: ignore
+    agent = Agent(function_model, history_processors=processors)  # type: ignore
+
+    original_history = [
+        ModelRequest(parts=[UserPromptPart(content='Question 1')]),
+        ModelResponse(parts=[TextPart(content='Answer 1')]),
+    ]
+
+    result = await agent.run('Question 2', message_history=original_history.copy())
+
+    # Verify original history is preserved
+    all_messages = result.all_messages()
+    assert len(all_messages) == 4  # 2 original + 1 new request + 1 new response
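As the `Agent.__init__` change above suggests, a plain list of processors is still accepted and is normalized into `HistoryProcessors(funcs=[...])` with the default `replace_history=False`. A minimal sketch of the two equivalent spellings, assuming the API introduced in this diff (the `drop_responses` helper is illustrative):

```python
from pydantic_ai import Agent, HistoryProcessors
from pydantic_ai.messages import ModelMessage, ModelRequest


def drop_responses(messages: list[ModelMessage]) -> list[ModelMessage]:
    # Send only request messages to the model.
    return [msg for msg in messages if isinstance(msg, ModelRequest)]


# These two agents behave identically: the list form is wrapped in a
# HistoryProcessors with replace_history=False, so the stored conversation
# history is left untouched in both cases.
agent_a = Agent('openai:gpt-4o', history_processors=[drop_responses])
agent_b = Agent('openai:gpt-4o', history_processors=HistoryProcessors(funcs=[drop_responses]))
```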