Skip to content

Commit 0810eed

Browse files
authored
Stream tool calls and structured output from Anthropic as it's received instead of in one go (#1669)
1 parent b7584ab commit 0810eed

File tree

2 files changed

+10
-17
lines changed

2 files changed

+10
-17
lines changed

pydantic_ai_slim/pydantic_ai/models/anthropic.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from contextlib import asynccontextmanager
66
from dataclasses import dataclass, field
77
from datetime import datetime, timezone
8-
from json import JSONDecodeError, loads as json_loads
98
from typing import Any, Literal, Union, cast, overload
109

1110
from typing_extensions import assert_never
@@ -440,7 +439,6 @@ class AnthropicStreamedResponse(StreamedResponse):
440439

441440
async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
442441
current_block: ContentBlock | None = None
443-
current_json: str = ''
444442

445443
async for event in self._response:
446444
self._usage += _map_usage(event)
@@ -455,7 +453,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
455453
maybe_event = self._parts_manager.handle_tool_call_delta(
456454
vendor_part_id=current_block.id,
457455
tool_name=current_block.name,
458-
args=cast(dict[str, Any], current_block.input),
456+
args=cast(dict[str, Any], current_block.input) or None,
459457
tool_call_id=current_block.id,
460458
)
461459
if maybe_event is not None: # pragma: no branch
@@ -469,20 +467,10 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
469467
elif ( # pragma: no branch
470468
current_block and event.delta.type == 'input_json_delta' and isinstance(current_block, ToolUseBlock)
471469
):
472-
# Try to parse the JSON immediately, otherwise cache the value for later. This handles
473-
# cases where the JSON is not currently valid but will be valid once we stream more tokens.
474-
try:
475-
parsed_args = json_loads(current_json + event.delta.partial_json)
476-
current_json = ''
477-
except JSONDecodeError:
478-
current_json += event.delta.partial_json
479-
continue
480-
481-
# For tool calls, we need to handle partial JSON updates
482470
maybe_event = self._parts_manager.handle_tool_call_delta(
483471
vendor_part_id=current_block.id,
484472
tool_name='',
485-
args=parsed_args,
473+
args=event.delta.partial_json,
486474
tool_call_id=current_block.id,
487475
)
488476
if maybe_event is not None: # pragma: no branch

tests/models/test_anthropic.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -583,18 +583,23 @@ async def test_stream_structured(allow_model_requests: None):
583583
RawContentBlockStartEvent(
584584
type='content_block_start',
585585
index=0,
586-
content_block=ToolUseBlock(type='tool_use', id='tool_1', name='my_tool', input={'first': 'One'}),
586+
content_block=ToolUseBlock(type='tool_use', id='tool_1', name='my_tool', input={}),
587587
),
588588
# Add more data through an incomplete JSON delta
589589
RawContentBlockDeltaEvent(
590590
type='content_block_delta',
591591
index=0,
592-
delta=InputJSONDelta(type='input_json_delta', partial_json='{"second":'),
592+
delta=InputJSONDelta(type='input_json_delta', partial_json='{"first": "One'),
593593
),
594594
RawContentBlockDeltaEvent(
595595
type='content_block_delta',
596596
index=0,
597-
delta=InputJSONDelta(type='input_json_delta', partial_json='"Two"}'),
597+
delta=InputJSONDelta(type='input_json_delta', partial_json='", "second": "Two"'),
598+
),
599+
RawContentBlockDeltaEvent(
600+
type='content_block_delta',
601+
index=0,
602+
delta=InputJSONDelta(type='input_json_delta', partial_json='}'),
598603
),
599604
# Mark tool block as complete
600605
RawContentBlockStopEvent(type='content_block_stop', index=0),

0 commit comments

Comments
 (0)