Skip to content

Commit 45401fe

Browse files
committed
feat(TestModel): tool part delta option
Add `tool_call_delta` parameter to `TestModel` to control whether tool responses contain ToolCallPartDelta's or not. This allows consumers to choose between receiving full tool responses or partial updates, enhancing test coverage.
1 parent 25746b3 commit 45401fe

File tree

1 file changed

+34
-5
lines changed
  • pydantic_ai_slim/pydantic_ai/models

1 file changed

+34
-5
lines changed

pydantic_ai_slim/pydantic_ai/models/test.py

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ class TestModel(Model):
6363

6464
call_tools: list[str] | Literal['all'] = 'all'
6565
"""List of tools to call. If `'all'`, all tools will be called."""
66+
tool_call_deltas: set[str] = field(default_factory=set)
67+
"""A set of tool call names which should result in tool call part deltas."""
6668
custom_output_text: str | None = None
6769
"""If set, this text is returned as the final output."""
6870
custom_output_args: Any | None = None
@@ -102,7 +104,10 @@ async def request_stream(
102104

103105
model_response = self._request(messages, model_settings, model_request_parameters)
104106
yield TestStreamedResponse(
105-
_model_name=self._model_name, _structured_response=model_response, _messages=messages
107+
_model_name=self._model_name,
108+
_structured_response=model_response,
109+
_messages=messages,
110+
_tool_call_deltas=self.tool_call_deltas,
106111
)
107112

108113
@property
@@ -218,7 +223,8 @@ def _request(
218223
output_tool = output_tools[self.seed % len(output_tools)]
219224
if custom_output_args is not None:
220225
return ModelResponse(
221-
parts=[ToolCallPart(output_tool.name, custom_output_args)], model_name=self._model_name
226+
parts=[ToolCallPart(output_tool.name, custom_output_args)],
227+
model_name=self._model_name,
222228
)
223229
else:
224230
response_args = self.gen_tool_args(output_tool)
@@ -232,6 +238,7 @@ class TestStreamedResponse(StreamedResponse):
232238
_model_name: str
233239
_structured_response: ModelResponse
234240
_messages: InitVar[Iterable[ModelMessage]]
241+
_tool_call_deltas: set[str]
235242
_timestamp: datetime = field(default_factory=_utils.now_utc, init=False)
236243

237244
def __post_init__(self, _messages: Iterable[ModelMessage]):
@@ -253,9 +260,31 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
253260
self._usage += _get_string_usage(word)
254261
yield self._parts_manager.handle_text_delta(vendor_part_id=i, content=word)
255262
elif isinstance(part, ToolCallPart):
256-
yield self._parts_manager.handle_tool_call_part(
257-
vendor_part_id=i, tool_name=part.tool_name, args=part.args, tool_call_id=part.tool_call_id
258-
)
263+
if part.tool_name in self._tool_call_deltas:
264+
# Start with empty tool call delta.
265+
event = self._parts_manager.handle_tool_call_delta(
266+
vendor_part_id=i, tool_name=part.tool_name, args='', tool_call_id=part.tool_call_id
267+
)
268+
if event is not None:
269+
yield event
270+
271+
# Stream the args as JSON string in chunks.
272+
args_json = pydantic_core.to_json(part.args).decode()
273+
*chunks, last_chunk = args_json.split(',') if ',' in args_json else [args_json]
274+
chunks = [f'{chunk},' for chunk in chunks] if chunks else []
275+
if last_chunk:
276+
chunks.append(last_chunk)
277+
278+
for chunk in chunks:
279+
event = self._parts_manager.handle_tool_call_delta(
280+
vendor_part_id=i, tool_name=None, args=chunk, tool_call_id=part.tool_call_id
281+
)
282+
if event is not None:
283+
yield event
284+
else:
285+
yield self._parts_manager.handle_tool_call_part(
286+
vendor_part_id=i, tool_name=part.tool_name, args=part.args, tool_call_id=part.tool_call_id
287+
)
259288
elif isinstance(part, ThinkingPart): # pragma: no cover
260289
# NOTE: There's no way to reach this part of the code, since we don't generate ThinkingPart on TestModel.
261290
assert False, "This should be unreachable — we don't generate ThinkingPart on TestModel."

0 commit comments

Comments
 (0)