Skip to content

Commit 1a33c02

Browse files
DouweMKludex
andauthored
Properly validate serialized messages with BinaryContent by decoding base64 (#1513)
Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
1 parent babdf82 commit 1a33c02

File tree

2 files changed

+11
-3
lines changed

2 files changed

+11
-3
lines changed

pydantic_ai_slim/pydantic_ai/messages.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -589,7 +589,7 @@ def new_event_body():
589589
"""Any message sent to or returned by a model."""
590590

591591
ModelMessagesTypeAdapter = pydantic.TypeAdapter(
592-
list[ModelMessage], config=pydantic.ConfigDict(defer_build=True, ser_json_bytes='base64')
592+
list[ModelMessage], config=pydantic.ConfigDict(defer_build=True, ser_json_bytes='base64', val_json_bytes='base64')
593593
)
594594
"""Pydantic [`TypeAdapter`][pydantic.type_adapter.TypeAdapter] for (de)serializing messages."""
595595

tests/test_agent.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from pydantic_ai.messages import (
1616
BinaryContent,
1717
ModelMessage,
18+
ModelMessagesTypeAdapter,
1819
ModelRequest,
1920
ModelResponse,
2021
ModelResponsePart,
@@ -1675,8 +1676,11 @@ def validate_output(ctx: RunContext[None], o: Any) -> Any: # pragma: no cover
16751676
def test_binary_content_all_messages_json():
16761677
agent = Agent('test')
16771678

1678-
result = agent.run_sync(['Hello', BinaryContent(data=b'Hello', media_type='text/plain')])
1679-
assert json.loads(result.all_messages_json()) == snapshot(
1679+
content = BinaryContent(data=b'Hello', media_type='text/plain')
1680+
result = agent.run_sync(['Hello', content])
1681+
1682+
serialized = result.all_messages_json()
1683+
assert json.loads(serialized) == snapshot(
16801684
[
16811685
{
16821686
'parts': [
@@ -1698,6 +1702,10 @@ def test_binary_content_all_messages_json():
16981702
]
16991703
)
17001704

1705+
# We also need to be able to round trip the serialized messages.
1706+
messages = ModelMessagesTypeAdapter.validate_json(serialized)
1707+
assert messages == result.all_messages()
1708+
17011709

17021710
def test_instructions_raise_error_when_system_prompt_is_set():
17031711
agent = Agent('test', instructions='An instructions!')

0 commit comments

Comments
 (0)