Skip to content

Commit b7584ab

Browse files
authored
Support streaming tool calls from models that pass args as None when there are no function parameters (#1802)
1 parent 4967685 commit b7584ab

File tree

13 files changed

+99
-36
lines changed

13 files changed

+99
-36
lines changed

docs/agents.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,8 @@ if __name__ == '__main__':
370370
[
371371
'=== UserPromptNode: What will the weather be like in Paris on Tuesday? ===',
372372
'=== ModelRequestNode: streaming partial request tokens ===',
373-
'[Request] Starting part 0: ToolCallPart(tool_name=\'weather_forecast\', args=\'{"location":"Pa\', tool_call_id=\'0001\', part_kind=\'tool-call\')',
373+
"[Request] Starting part 0: ToolCallPart(tool_name='weather_forecast', args=None, tool_call_id='0001', part_kind='tool-call')",
374+
'[Request] Part 0 args_delta={"location":"Pa',
374375
'[Request] Part 0 args_delta=ris","forecast_',
375376
'[Request] Part 0 args_delta=date":"2030-01-',
376377
'[Request] Part 0 args_delta=01"}',

pydantic_ai_slim/pydantic_ai/_output.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -231,9 +231,13 @@ def validate(
231231
try:
232232
pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off'
233233
if isinstance(tool_call.args, str):
234-
output = self.type_adapter.validate_json(tool_call.args, experimental_allow_partial=pyd_allow_partial)
234+
output = self.type_adapter.validate_json(
235+
tool_call.args or '{}', experimental_allow_partial=pyd_allow_partial
236+
)
235237
else:
236-
output = self.type_adapter.validate_python(tool_call.args, experimental_allow_partial=pyd_allow_partial)
238+
output = self.type_adapter.validate_python(
239+
tool_call.args or {}, experimental_allow_partial=pyd_allow_partial
240+
)
237241
except ValidationError as e:
238242
if wrap_validation_errors:
239243
m = _messages.RetryPromptPart(

pydantic_ai_slim/pydantic_ai/_parts_manager.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def handle_tool_call_delta(
132132
) -> ModelResponseStreamEvent | None:
133133
"""Handle or update a tool call, creating or updating a `ToolCallPart` or `ToolCallPartDelta`.
134134
135-
Managed items remain as `ToolCallPartDelta`s until they have both a tool_name and arguments, at which
135+
Managed items remain as `ToolCallPartDelta`s until they have at least a tool_name, at which
136136
point they are upgraded to `ToolCallPart`s.
137137
138138
If `vendor_part_id` is None, updates the latest matching ToolCallPart (or ToolCallPartDelta)
@@ -143,11 +143,11 @@ def handle_tool_call_delta(
143143
If None, the latest matching tool call may be updated.
144144
tool_name: The name of the tool. If None, the manager does not enforce
145145
a name match when `vendor_part_id` is None.
146-
args: Arguments for the tool call, either as a string or a dictionary of key-value pairs.
146+
args: Arguments for the tool call, either as a string, a dictionary of key-value pairs, or None.
147147
tool_call_id: An optional string representing an identifier for this tool call.
148148
149149
Returns:
150-
- A `PartStartEvent` if a new (fully realized) ToolCallPart is created.
150+
- A `PartStartEvent` if a new ToolCallPart is created.
151151
- A `PartDeltaEvent` if an existing part is updated.
152152
- `None` if no new event is emitted (e.g., the part is still incomplete).
153153
@@ -207,7 +207,7 @@ def handle_tool_call_part(
207207
*,
208208
vendor_part_id: Hashable | None,
209209
tool_name: str,
210-
args: str | dict[str, Any],
210+
args: str | dict[str, Any] | None,
211211
tool_call_id: str | None = None,
212212
) -> ModelResponseStreamEvent:
213213
"""Immediately create or fully-overwrite a ToolCallPart with the given information.
@@ -218,7 +218,7 @@ def handle_tool_call_part(
218218
vendor_part_id: The vendor's ID for this tool call part. If not
219219
None and an existing part is found, that part is overwritten.
220220
tool_name: The name of the tool being invoked.
221-
args: The arguments for the tool call, either as a string or a dictionary.
221+
args: The arguments for the tool call, either as a string, a dictionary, or None.
222222
tool_call_id: An optional string identifier for this tool call.
223223
224224
Returns:

pydantic_ai_slim/pydantic_ai/messages.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -486,7 +486,7 @@ class ToolCallPart:
486486
tool_name: str
487487
"""The name of the tool to call."""
488488

489-
args: str | dict[str, Any]
489+
args: str | dict[str, Any] | None = None
490490
"""The arguments to pass to the tool.
491491
492492
This is stored either as a JSON string or a Python dictionary depending on how data was received.
@@ -506,10 +506,10 @@ def args_as_dict(self) -> dict[str, Any]:
506506
507507
This is just for convenience with models that require dicts as input.
508508
"""
509+
if not self.args:
510+
return {}
509511
if isinstance(self.args, dict):
510512
return self.args
511-
if isinstance(self.args, str) and not self.args:
512-
return {}
513513
args = pydantic_core.from_json(self.args)
514514
assert isinstance(args, dict), 'args should be a dict'
515515
return cast(dict[str, Any], args)
@@ -519,6 +519,8 @@ def args_as_json_str(self) -> str:
519519
520520
This is just for convenience with models that require JSON strings as input.
521521
"""
522+
if not self.args:
523+
return '{}'
522524
if isinstance(self.args, str):
523525
return self.args
524526
return pydantic_core.to_json(self.args).decode()
@@ -666,9 +668,9 @@ def as_part(self) -> ToolCallPart | None:
666668
"""Convert this delta to a fully formed `ToolCallPart` if possible, otherwise return `None`.
667669
668670
Returns:
669-
A `ToolCallPart` if both `tool_name_delta` and `args_delta` are set, otherwise `None`.
671+
A `ToolCallPart` if `tool_name_delta` is set, otherwise `None`.
670672
"""
671-
if self.tool_name_delta is None or self.args_delta is None:
673+
if self.tool_name_delta is None:
672674
return None
673675

674676
return ToolCallPart(self.tool_name_delta, self.args_delta, self.tool_call_id or _generate_tool_call_id())
@@ -728,7 +730,7 @@ def _apply_to_delta(self, delta: ToolCallPartDelta) -> ToolCallPart | ToolCallPa
728730
delta = replace(delta, tool_call_id=self.tool_call_id)
729731

730732
# If we now have enough data to create a full ToolCallPart, do so
731-
if delta.tool_name_delta is not None and delta.args_delta is not None:
733+
if delta.tool_name_delta is not None:
732734
return ToolCallPart(delta.tool_name_delta, delta.args_delta, delta.tool_call_id or _generate_tool_call_id())
733735

734736
return delta
@@ -741,12 +743,12 @@ def _apply_to_part(self, part: ToolCallPart) -> ToolCallPart:
741743
part = replace(part, tool_name=tool_name)
742744

743745
if isinstance(self.args_delta, str):
744-
if not isinstance(part.args, str):
746+
if isinstance(part.args, dict):
745747
raise UnexpectedModelBehavior(f'Cannot apply JSON deltas to non-JSON tool arguments ({part=}, {self=})')
746-
updated_json = part.args + self.args_delta
748+
updated_json = (part.args or '') + self.args_delta
747749
part = replace(part, args=updated_json)
748750
elif isinstance(self.args_delta, dict):
749-
if not isinstance(part.args, dict):
751+
if isinstance(part.args, str):
750752
raise UnexpectedModelBehavior(f'Cannot apply dict deltas to non-dict tool arguments ({part=}, {self=})')
751753
updated_dict = {**(part.args or {}), **self.args_delta}
752754
part = replace(part, args=updated_dict)

pydantic_ai_slim/pydantic_ai/models/bedrock.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -552,8 +552,8 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
552552
args=None,
553553
tool_call_id=tool_id,
554554
)
555-
if maybe_event:
556-
yield maybe_event # pragma: no cover
555+
if maybe_event: # pragma: no branch
556+
yield maybe_event
557557
if 'contentBlockDelta' in chunk:
558558
index = chunk['contentBlockDelta']['contentBlockIndex']
559559
delta = chunk['contentBlockDelta']['delta']

pydantic_ai_slim/pydantic_ai/models/google.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -442,7 +442,7 @@ def _process_response_from_parts(parts: list[Part], model_name: GoogleModelName,
442442
items.append(TextPart(content=part.text))
443443
elif part.function_call:
444444
assert part.function_call.name is not None
445-
tool_call_part = ToolCallPart(tool_name=part.function_call.name, args=part.function_call.args or {})
445+
tool_call_part = ToolCallPart(tool_name=part.function_call.name, args=part.function_call.args)
446446
if part.function_call.id is not None:
447447
tool_call_part.tool_call_id = part.function_call.id # pragma: no cover
448448
items.append(tool_call_part)

pydantic_ai_slim/pydantic_ai/models/mistral.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,7 @@ def _map_tool_call(t: ToolCallPart) -> MistralToolCall:
368368
return MistralToolCall(
369369
id=_utils.guard_tool_call_id(t=t),
370370
type='function',
371-
function=MistralFunctionCall(name=t.tool_name, arguments=t.args),
371+
function=MistralFunctionCall(name=t.tool_name, arguments=t.args or {}),
372372
)
373373

374374
def _generate_user_output_format(self, schemas: list[dict[str, Any]]) -> MistralUserMessage:

pydantic_ai_slim/pydantic_ai/tools.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,7 @@ async def _run(
367367
if isinstance(message.args, str):
368368
args_dict = self._validator.validate_json(message.args or '{}')
369369
else:
370-
args_dict = self._validator.validate_python(message.args)
370+
args_dict = self._validator.validate_python(message.args or {})
371371
except ValidationError as e:
372372
return self._on_error(e, message)
373373

tests/models/test_bedrock.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
TextPart,
2727
TextPartDelta,
2828
ToolCallPart,
29+
ToolCallPartDelta,
2930
ToolReturnPart,
3031
UserPromptPart,
3132
VideoUrl,
@@ -396,10 +397,11 @@ async def get_temperature(city: str) -> str:
396397
PartDeltaEvent(index=0, delta=TextPartDelta(content_delta='thinking')),
397398
PartDeltaEvent(index=0, delta=TextPartDelta(content_delta='>\n')),
398399
PartStartEvent(
400+
index=1, part=ToolCallPart(tool_name='get_temperature', tool_call_id='tooluse_lAG_zP8QRHmSYOwZzzaCqA')
401+
),
402+
PartDeltaEvent(
399403
index=1,
400-
part=ToolCallPart(
401-
tool_name='get_temperature', args='{"city":"Paris"}', tool_call_id='tooluse_lAG_zP8QRHmSYOwZzzaCqA'
402-
),
404+
delta=ToolCallPartDelta(args_delta='{"city":"Paris"}', tool_call_id='tooluse_lAG_zP8QRHmSYOwZzzaCqA'),
403405
),
404406
IsInstance(FunctionToolCallEvent),
405407
FunctionToolResultEvent(

tests/models/test_groq.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -432,6 +432,7 @@ async def test_stream_structured(allow_model_requests: None):
432432
assert not result.is_complete
433433
assert [dict(c) async for c in result.stream(debounce_by=None)] == snapshot(
434434
[
435+
{},
435436
{'first': 'One'},
436437
{'first': 'One', 'second': 'Two'},
437438
{'first': 'One', 'second': 'Two'},

0 commit comments

Comments
 (0)