@@ -63,6 +63,8 @@ class TestModel(Model):
63
63
64
64
call_tools : list [str ] | Literal ['all' ] = 'all'
65
65
"""List of tools to call. If `'all'`, all tools will be called."""
66
+ tool_call_deltas : set [str ] = field (default_factory = set )
67
+ """A set of tool call names which should result in tool call part deltas."""
66
68
custom_output_text : str | None = None
67
69
"""If set, this text is returned as the final output."""
68
70
custom_output_args : Any | None = None
@@ -102,7 +104,10 @@ async def request_stream(
102
104
103
105
model_response = self ._request (messages , model_settings , model_request_parameters )
104
106
yield TestStreamedResponse (
105
- _model_name = self ._model_name , _structured_response = model_response , _messages = messages
107
+ _model_name = self ._model_name ,
108
+ _structured_response = model_response ,
109
+ _messages = messages ,
110
+ _tool_call_deltas = self .tool_call_deltas ,
106
111
)
107
112
108
113
@property
@@ -218,7 +223,8 @@ def _request(
218
223
output_tool = output_tools [self .seed % len (output_tools )]
219
224
if custom_output_args is not None :
220
225
return ModelResponse (
221
- parts = [ToolCallPart (output_tool .name , custom_output_args )], model_name = self ._model_name
226
+ parts = [ToolCallPart (output_tool .name , custom_output_args )],
227
+ model_name = self ._model_name ,
222
228
)
223
229
else :
224
230
response_args = self .gen_tool_args (output_tool )
@@ -232,6 +238,7 @@ class TestStreamedResponse(StreamedResponse):
232
238
_model_name : str
233
239
_structured_response : ModelResponse
234
240
_messages : InitVar [Iterable [ModelMessage ]]
241
+ _tool_call_deltas : set [str ]
235
242
_timestamp : datetime = field (default_factory = _utils .now_utc , init = False )
236
243
237
244
def __post_init__ (self , _messages : Iterable [ModelMessage ]):
@@ -253,9 +260,31 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
253
260
self ._usage += _get_string_usage (word )
254
261
yield self ._parts_manager .handle_text_delta (vendor_part_id = i , content = word )
255
262
elif isinstance (part , ToolCallPart ):
256
- yield self ._parts_manager .handle_tool_call_part (
257
- vendor_part_id = i , tool_name = part .tool_name , args = part .args , tool_call_id = part .tool_call_id
258
- )
263
+ if part .tool_name in self ._tool_call_deltas :
264
+ # Start with empty tool call delta.
265
+ event = self ._parts_manager .handle_tool_call_delta (
266
+ vendor_part_id = i , tool_name = part .tool_name , args = '' , tool_call_id = part .tool_call_id
267
+ )
268
+ if event is not None :
269
+ yield event
270
+
271
+ # Stream the args as JSON string in chunks.
272
+ args_json = pydantic_core .to_json (part .args ).decode ()
273
+ * chunks , last_chunk = args_json .split (',' ) if ',' in args_json else [args_json ]
274
+ chunks = [f'{ chunk } ,' for chunk in chunks ] if chunks else []
275
+ if last_chunk :
276
+ chunks .append (last_chunk )
277
+
278
+ for chunk in chunks :
279
+ event = self ._parts_manager .handle_tool_call_delta (
280
+ vendor_part_id = i , tool_name = None , args = chunk , tool_call_id = part .tool_call_id
281
+ )
282
+ if event is not None :
283
+ yield event
284
+ else :
285
+ yield self ._parts_manager .handle_tool_call_part (
286
+ vendor_part_id = i , tool_name = part .tool_name , args = part .args , tool_call_id = part .tool_call_id
287
+ )
259
288
elif isinstance (part , ThinkingPart ): # pragma: no cover
260
289
# NOTE: There's no way to reach this part of the code, since we don't generate ThinkingPart on TestModel.
261
290
assert False , "This should be unreachable — we don't generate ThinkingPart on TestModel."
0 commit comments