Skip to content

Commit 206acb2

Browse files
committed
tests(ag-ui): tool call part delta coverage
Add new tests to cover tool call part delta handling in the adapter.
1 parent 45401fe commit 206acb2

File tree

2 files changed

+159
-11
lines changed

2 files changed

+159
-11
lines changed

adapter_ag_ui/adapter_ag_ui/adapter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,7 @@ async def _handle_agent_event(
453453
),
454454
]
455455
if agent_event.part.content:
456-
yield encoder.encode(
456+
yield encoder.encode( # pragma: no cover
457457
TextMessageContentEvent(
458458
type=EventType.TEXT_MESSAGE_CONTENT,
459459
message_id=message_id,

tests/adapter_ag_ui/test_adapter.py

Lines changed: 158 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,9 @@ class StateInt(BaseModel):
7777
value: int = 0
7878

7979

80-
def get_weather() -> Tool:
80+
def get_weather(name: str = 'get_weather') -> Tool:
8181
return Tool(
82-
name='get_weather',
82+
name=name,
8383
description='Get the weather for a given location',
8484
parameters={
8585
'type': 'object',
@@ -114,9 +114,9 @@ async def create_adapter(tools: list[str] | Literal['all'] = 'all') -> AdapterAG
114114
An AdapterAGUI instance configured with the specified tools.
115115
"""
116116
return Agent(
117-
model=TestModel(tools),
117+
model=TestModel(tools, tool_call_deltas={'get_weather_parts', 'current_time'}),
118118
deps_type=cast(type[StateDeps[StateInt]], StateDeps[StateInt]),
119-
tools=[send_snapshot, send_custom],
119+
tools=[send_snapshot, send_custom, current_time],
120120
).to_ag_ui()
121121

122122

@@ -182,6 +182,15 @@ def normalize_uuids(text: str) -> str:
182182
return UUID_PATTERN.sub('00000000-0000-0000-0000-000000000001', text)
183183

184184

185+
def current_time() -> str:
186+
"""Get the current time in ISO format.
187+
188+
Returns:
189+
The current UTC time in ISO format string.
190+
"""
191+
return '21T12:08:45.485981+00:00'
192+
193+
185194
async def send_snapshot() -> StateSnapshotEvent:
186195
"""Display the recipe to the user.
187196
@@ -353,7 +362,7 @@ def tc_parameters() -> list[AdapterRunTest]:
353362
],
354363
),
355364
AdapterRunTest(
356-
id='tool_call',
365+
id='tool_call_ag_ui',
357366
call_tools=['get_weather'],
358367
runs=[
359368
Run(
@@ -411,29 +420,145 @@ def tc_parameters() -> list[AdapterRunTest]:
411420
],
412421
),
413422
AdapterRunTest(
414-
id='tool_call_no_result',
415-
call_tools=['get_weather'],
423+
id='tool_call_ag_ui_multiple',
424+
call_tools=['get_weather', 'get_weather_parts'],
416425
runs=[
426+
Run(
427+
messages=[ # pyright: ignore[reportArgumentType]
428+
UserMessage(
429+
id='msg_1',
430+
role=Role.USER.value,
431+
content='Please call get_weather and get_weather_parts for Paris',
432+
),
433+
],
434+
tools=[get_weather(), get_weather('get_weather_parts')],
435+
),
417436
Run(
418437
messages=[ # pyright: ignore[reportArgumentType]
419438
UserMessage(
420439
id='msg_1',
421440
role=Role.USER.value,
422441
content='Please call get_weather for Paris',
423442
),
443+
AssistantMessage(
444+
id='msg_2',
445+
role=Role.ASSISTANT.value,
446+
tool_calls=[
447+
ToolCall(
448+
id='pyd_ai_00000000000000000000000000000003',
449+
type='function',
450+
function=FunctionCall(
451+
name='get_weather',
452+
arguments='{"location": "Paris"}',
453+
),
454+
),
455+
],
456+
),
457+
ToolMessage(
458+
id='msg_3',
459+
role=Role.TOOL.value,
460+
content='Tool result',
461+
tool_call_id='pyd_ai_00000000000000000000000000000003',
462+
),
463+
AssistantMessage(
464+
id='msg_4',
465+
role=Role.ASSISTANT.value,
466+
tool_calls=[
467+
ToolCall(
468+
id='pyd_ai_00000000000000000000000000000003',
469+
type='function',
470+
function=FunctionCall(
471+
name='get_weather_parts',
472+
arguments='{"location": "Paris"}',
473+
),
474+
),
475+
],
476+
),
477+
ToolMessage(
478+
id='msg_5',
479+
role=Role.TOOL.value,
480+
content='Tool result',
481+
tool_call_id='pyd_ai_00000000000000000000000000000003',
482+
),
424483
],
425-
tools=[get_weather()],
484+
tools=[get_weather(), get_weather('get_weather_parts')],
426485
),
427486
],
428487
expected_events=[
429488
'{"type":"RUN_STARTED","threadId":"thread_00000000-0000-0000-0000-000000000001","runId":"run_00000000-0000-0000-0000-000000000002"}',
430489
'{"type":"TOOL_CALL_START","toolCallId":"pyd_ai_00000000000000000000000000000003","toolCallName":"get_weather"}',
431490
'{"type":"TOOL_CALL_END","toolCallId":"pyd_ai_00000000000000000000000000000003"}',
432491
'{"type":"RUN_FINISHED","threadId":"thread_00000000-0000-0000-0000-000000000001","runId":"run_00000000-0000-0000-0000-000000000002"}',
492+
'{"type":"RUN_STARTED","threadId":"thread_00000000-0000-0000-0000-000000000001","runId":"run_00000000-0000-0000-0000-000000000005"}',
493+
'{"type":"TEXT_MESSAGE_START","messageId":"00000000-0000-0000-0000-000000000006","role":"assistant"}',
494+
'{"type":"TEXT_MESSAGE_CONTENT","messageId":"00000000-0000-0000-0000-000000000006","delta":"{\\"get_weather\\":\\"Tool "}',
495+
'{"type":"TEXT_MESSAGE_CONTENT","messageId":"00000000-0000-0000-0000-000000000006","delta":"result\\",\\"get_weather_parts\\":\\"Tool "}',
496+
'{"type":"TEXT_MESSAGE_CONTENT","messageId":"00000000-0000-0000-0000-000000000006","delta":"result\\"}"}',
497+
'{"type":"TEXT_MESSAGE_END","messageId":"00000000-0000-0000-0000-000000000006"}',
498+
'{"type":"RUN_FINISHED","threadId":"thread_00000000-0000-0000-0000-000000000001","runId":"run_00000000-0000-0000-0000-000000000005"}',
433499
],
434500
),
435501
AdapterRunTest(
436-
id='tool_single_event',
502+
id='tool_call_ag_ui_parts',
503+
call_tools=['get_weather_parts'],
504+
runs=[
505+
Run(
506+
messages=[ # pyright: ignore[reportArgumentType]
507+
UserMessage(
508+
id='msg_1',
509+
role=Role.USER.value,
510+
content='Please call get_weather_parts for Paris',
511+
),
512+
],
513+
tools=[get_weather('get_weather_parts')],
514+
),
515+
Run(
516+
messages=[ # pyright: ignore[reportArgumentType]
517+
UserMessage(
518+
id='msg_1',
519+
role=Role.USER.value,
520+
content='Please call get_weather_parts for Paris',
521+
),
522+
AssistantMessage(
523+
id='msg_2',
524+
role=Role.ASSISTANT.value,
525+
tool_calls=[
526+
ToolCall(
527+
id='pyd_ai_00000000000000000000000000000003',
528+
type='function',
529+
function=FunctionCall(
530+
name='get_weather_parts',
531+
arguments='{"location": "Paris"}',
532+
),
533+
),
534+
],
535+
),
536+
ToolMessage(
537+
id='msg_3',
538+
role=Role.TOOL.value,
539+
content='Tool result',
540+
tool_call_id='pyd_ai_00000000000000000000000000000003',
541+
),
542+
],
543+
tools=[get_weather('get_weather_parts')],
544+
),
545+
],
546+
expected_events=[
547+
'{"type":"RUN_STARTED","threadId":"thread_00000000-0000-0000-0000-000000000001","runId":"run_00000000-0000-0000-0000-000000000002"}',
548+
'{"type":"TOOL_CALL_START","toolCallId":"pyd_ai_00000000000000000000000000000003","toolCallName":"get_weather_parts"}',
549+
'{"type":"TOOL_CALL_ARGS","toolCallId":"pyd_ai_00000000000000000000000000000003","delta":"{\\"location\\":\\"a\\"}"}',
550+
'{"type":"TOOL_CALL_END","toolCallId":"pyd_ai_00000000000000000000000000000003"}',
551+
'{"type":"RUN_FINISHED","threadId":"thread_00000000-0000-0000-0000-000000000001","runId":"run_00000000-0000-0000-0000-000000000002"}',
552+
'{"type":"RUN_STARTED","threadId":"thread_00000000-0000-0000-0000-000000000001","runId":"run_00000000-0000-0000-0000-000000000004"}',
553+
'{"type":"TEXT_MESSAGE_START","messageId":"00000000-0000-0000-0000-000000000005","role":"assistant"}',
554+
'{"type":"TEXT_MESSAGE_CONTENT","messageId":"00000000-0000-0000-0000-000000000005","delta":"{\\"get_weather_parts\\":\\"Tool "}',
555+
'{"type":"TEXT_MESSAGE_CONTENT","messageId":"00000000-0000-0000-0000-000000000005","delta":"result\\"}"}',
556+
'{"type":"TEXT_MESSAGE_END","messageId":"00000000-0000-0000-0000-000000000005"}',
557+
'{"type":"RUN_FINISHED","threadId":"thread_00000000-0000-0000-0000-000000000001","runId":"run_00000000-0000-0000-0000-000000000004"}',
558+
],
559+
),
560+
AdapterRunTest(
561+
id='tool_local_single_event',
437562
call_tools=['send_snapshot'],
438563
runs=[
439564
Run(
@@ -457,7 +582,7 @@ def tc_parameters() -> list[AdapterRunTest]:
457582
],
458583
),
459584
AdapterRunTest(
460-
id='tool_multiple_events',
585+
id='tool_local_multiple_events',
461586
call_tools=['send_custom'],
462587
runs=[
463588
Run(
@@ -481,6 +606,29 @@ def tc_parameters() -> list[AdapterRunTest]:
481606
'{"type":"RUN_FINISHED","threadId":"thread_00000000-0000-0000-0000-000000000001","runId":"run_00000000-0000-0000-0000-000000000002"}',
482607
],
483608
),
609+
AdapterRunTest(
610+
id='tool_local_parts',
611+
call_tools=['current_time'],
612+
runs=[
613+
Run(
614+
messages=[ # pyright: ignore[reportArgumentType]
615+
UserMessage(
616+
id='msg_1',
617+
role=Role.USER.value,
618+
content='Please call current_time',
619+
),
620+
],
621+
),
622+
],
623+
expected_events=[
624+
'{"type":"RUN_STARTED","threadId":"thread_00000000-0000-0000-0000-000000000001","runId":"run_00000000-0000-0000-0000-000000000002"}',
625+
'{"type":"TEXT_MESSAGE_START","messageId":"00000000-0000-0000-0000-000000000004","role":"assistant"}',
626+
'{"type":"TEXT_MESSAGE_CONTENT","messageId":"00000000-0000-0000-0000-000000000004","delta":"{\\"current_time\\":\\"21T1"}',
627+
'{"type":"TEXT_MESSAGE_CONTENT","messageId":"00000000-0000-0000-0000-000000000004","delta":"2:08:45.485981+00:00\\"}"}',
628+
'{"type":"TEXT_MESSAGE_END","messageId":"00000000-0000-0000-0000-000000000004"}',
629+
'{"type":"RUN_FINISHED","threadId":"thread_00000000-0000-0000-0000-000000000001","runId":"run_00000000-0000-0000-0000-000000000002"}',
630+
],
631+
),
484632
AdapterRunTest(
485633
id='request_with_state',
486634
runs=[

0 commit comments

Comments
 (0)