Skip to content

Commit aae9022

Browse files
authored
Return last text parts on empty message (#1408)
1 parent a45b8e1 commit aae9022

File tree

2 files changed

+69
-1
lines changed

2 files changed

+69
-1
lines changed

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,18 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]:
427427
# No events are emitted during the handling of text responses, so we don't need to yield anything
428428
self._next_node = await self._handle_text_response(ctx, texts)
429429
else:
430+
# we've got an empty response, this sometimes happens with anthropic (and perhaps other models)
431+
# when the model has already returned text along side tool calls
432+
# in this scenario, if text responses are allowed, we return text from the most recent model
433+
# response, if any
434+
if allow_text_output(ctx.deps.output_schema):
435+
for message in reversed(ctx.state.message_history):
436+
if isinstance(message, _messages.ModelResponse):
437+
last_texts = [p.content for p in message.parts if isinstance(p, _messages.TextPart)]
438+
if last_texts:
439+
self._next_node = await self._handle_text_response(ctx, last_texts)
440+
return
441+
430442
raise exceptions.UnexpectedModelBehavior('Received empty model response')
431443

432444
self._events_iterator = _run_stream()
@@ -530,14 +542,14 @@ async def _handle_text_response(
530542

531543
text = '\n\n'.join(texts)
532544
if allow_text_output(output_schema):
545+
# The following cast is safe because we know `str` is an allowed result type
533546
result_data_input = cast(NodeRunEndT, text)
534547
try:
535548
result_data = await _validate_output(result_data_input, ctx, None)
536549
except _output.ToolRetryError as e:
537550
ctx.state.increment_retries(ctx.deps.max_result_retries)
538551
return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry]))
539552
else:
540-
# The following cast is safe because we know `str` is an allowed result type
541553
return self._handle_final_result(ctx, result.FinalResult(result_data, None, None), [])
542554
else:
543555
ctx.state.increment_retries(ctx.deps.max_result_retries)

tests/test_agent.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1813,3 +1813,59 @@ def test_instructions_with_message_history():
18131813
),
18141814
]
18151815
)
1816+
1817+
1818+
def test_empty_final_response():
1819+
def llm(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
1820+
if len(messages) == 1:
1821+
return ModelResponse(parts=[TextPart('foo'), ToolCallPart('my_tool', {'x': 1})])
1822+
elif len(messages) == 3:
1823+
return ModelResponse(parts=[TextPart('bar'), ToolCallPart('my_tool', {'x': 2})])
1824+
else:
1825+
return ModelResponse(parts=[])
1826+
1827+
agent = Agent(FunctionModel(llm))
1828+
1829+
@agent.tool_plain
1830+
def my_tool(x: int) -> int:
1831+
return x * 2
1832+
1833+
result = agent.run_sync('Hello')
1834+
assert result.output == 'bar'
1835+
1836+
assert result.new_messages() == snapshot(
1837+
[
1838+
ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
1839+
ModelResponse(
1840+
parts=[
1841+
TextPart(content='foo'),
1842+
ToolCallPart(tool_name='my_tool', args={'x': 1}, tool_call_id=IsStr()),
1843+
],
1844+
model_name='function:llm:',
1845+
timestamp=IsNow(tz=timezone.utc),
1846+
),
1847+
ModelRequest(
1848+
parts=[
1849+
ToolReturnPart(
1850+
tool_name='my_tool', content=2, tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc)
1851+
)
1852+
]
1853+
),
1854+
ModelResponse(
1855+
parts=[
1856+
TextPart(content='bar'),
1857+
ToolCallPart(tool_name='my_tool', args={'x': 2}, tool_call_id=IsStr()),
1858+
],
1859+
model_name='function:llm:',
1860+
timestamp=IsNow(tz=timezone.utc),
1861+
),
1862+
ModelRequest(
1863+
parts=[
1864+
ToolReturnPart(
1865+
tool_name='my_tool', content=4, tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc)
1866+
)
1867+
]
1868+
),
1869+
ModelResponse(parts=[], model_name='function:llm:', timestamp=IsNow(tz=timezone.utc)),
1870+
]
1871+
)

0 commit comments

Comments
 (0)