Skip to content

Commit 6e6fee9

Browse files
pandersen-unchartedseanw7DouweM
authored
Fix parallel tool calls on Bedrock with Nova and Claude models (#1656)
Co-authored-by: seanw7 <seanw@protonmail.ch> Co-authored-by: Douwe Maan <douwe@pydantic.dev>
1 parent 1282f32 commit 6e6fee9

File tree

2 files changed

+72
-8
lines changed

2 files changed

+72
-8
lines changed

pydantic_ai_slim/pydantic_ai/models/bedrock.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -367,13 +367,16 @@ def _map_inference_config(
367367
async def _map_messages(
368368
self, messages: list[ModelMessage]
369369
) -> tuple[list[SystemContentBlockTypeDef], list[MessageUnionTypeDef]]:
370-
"""Just maps a `pydantic_ai.Message` to the Bedrock `MessageUnionTypeDef`."""
370+
"""Maps a `pydantic_ai.Message` to the Bedrock `MessageUnionTypeDef`.
371+
372+
Groups consecutive ToolReturnPart objects into a single user message as required by Bedrock Claude/Nova models.
373+
"""
371374
system_prompt: list[SystemContentBlockTypeDef] = []
372375
bedrock_messages: list[MessageUnionTypeDef] = []
373376
document_count: Iterator[int] = count(1)
374-
for m in messages:
375-
if isinstance(m, ModelRequest):
376-
for part in m.parts:
377+
for message in messages:
378+
if isinstance(message, ModelRequest):
379+
for part in message.parts:
377380
if isinstance(part, SystemPromptPart):
378381
system_prompt.append({'text': part.content})
379382
elif isinstance(part, UserPromptPart):
@@ -414,22 +417,41 @@ async def _map_messages(
414417
],
415418
}
416419
)
417-
elif isinstance(m, ModelResponse):
420+
elif isinstance(message, ModelResponse):
418421
content: list[ContentBlockOutputTypeDef] = []
419-
for item in m.parts:
422+
for item in message.parts:
420423
if isinstance(item, TextPart):
421424
content.append({'text': item.content})
422425
else:
423426
assert isinstance(item, ToolCallPart)
424427
content.append(self._map_tool_call(item))
425428
bedrock_messages.append({'role': 'assistant', 'content': content})
426429
else:
427-
assert_never(m)
430+
assert_never(message)
431+
432+
# Merge together sequential user messages.
433+
processed_messages: list[MessageUnionTypeDef] = []
434+
last_message: dict[str, Any] | None = None
435+
for current_message in bedrock_messages:
436+
if (
437+
last_message is not None
438+
and current_message['role'] == last_message['role']
439+
and current_message['role'] == 'user'
440+
):
441+
# Add the new user content onto the existing user message.
442+
last_content = list(last_message['content'])
443+
last_content.extend(current_message['content'])
444+
last_message['content'] = last_content
445+
continue
446+
447+
# Add the entire message to the list of messages.
448+
processed_messages.append(current_message)
449+
last_message = cast(dict[str, Any], current_message)
428450

429451
if instructions := self._get_instructions(messages):
430452
system_prompt.insert(0, {'text': instructions})
431453

432-
return system_prompt, bedrock_messages
454+
return system_prompt, processed_messages
433455

434456
@staticmethod
435457
async def _map_user_prompt(part: UserPromptPart, document_count: Iterator[int]) -> list[MessageUnionTypeDef]:

tests/models/test_bedrock.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -581,3 +581,45 @@ async def test_bedrock_multiple_documents_in_history(
581581
assert result.output == snapshot(
582582
'Based on the documents you\'ve shared, both Document 1.pdf and Document 2.pdf contain the text "Dummy PDF file". These appear to be placeholder or sample PDF documents rather than files with substantial content.'
583583
)
584+
585+
586+
async def test_bedrock_group_consecutive_tool_return_parts(bedrock_provider: BedrockProvider):
587+
"""
588+
Test that consecutive ToolReturnPart objects are grouped into a single user message for Bedrock.
589+
"""
590+
model = BedrockConverseModel('us.amazon.nova-micro-v1:0', provider=bedrock_provider)
591+
now = datetime.datetime.now()
592+
# Create a ModelRequest with 3 consecutive ToolReturnParts
593+
req = [
594+
ModelRequest(parts=[UserPromptPart(content=['Hello'])]),
595+
ModelResponse(parts=[TextPart(content='Hi')]),
596+
ModelRequest(parts=[UserPromptPart(content=['How are you?'])]),
597+
ModelResponse(parts=[TextPart(content='Cloudy')]),
598+
ModelRequest(
599+
parts=[
600+
ToolReturnPart(tool_name='tool1', content='result1', tool_call_id='id1', timestamp=now),
601+
ToolReturnPart(tool_name='tool2', content='result2', tool_call_id='id2', timestamp=now),
602+
ToolReturnPart(tool_name='tool3', content='result3', tool_call_id='id3', timestamp=now),
603+
]
604+
),
605+
]
606+
607+
# Call the mapping function directly
608+
_, bedrock_messages = await model._map_messages(req) # type: ignore[reportPrivateUsage]
609+
610+
assert bedrock_messages == snapshot(
611+
[
612+
{'role': 'user', 'content': [{'text': 'Hello'}]},
613+
{'role': 'assistant', 'content': [{'text': 'Hi'}]},
614+
{'role': 'user', 'content': [{'text': 'How are you?'}]},
615+
{'role': 'assistant', 'content': [{'text': 'Cloudy'}]},
616+
{
617+
'role': 'user',
618+
'content': [
619+
{'toolResult': {'toolUseId': 'id1', 'content': [{'text': 'result1'}], 'status': 'success'}},
620+
{'toolResult': {'toolUseId': 'id2', 'content': [{'text': 'result2'}], 'status': 'success'}},
621+
{'toolResult': {'toolUseId': 'id3', 'content': [{'text': 'result3'}], 'status': 'success'}},
622+
],
623+
},
624+
]
625+
)

0 commit comments

Comments
 (0)