11
11
from collections .abc import Callable
12
12
from dataclasses import dataclass , field
13
13
from itertools import count
14
- from typing import Any , Final , Literal
14
+ from typing import Any , Final , Literal , cast
15
15
16
16
import pytest
17
+ from pydantic import BaseModel
17
18
18
19
from pydantic_ai import Agent
19
20
from pydantic_ai .models .test import TestModel
40
41
41
42
from adapter_ag_ui ._enums import Role
42
43
from adapter_ag_ui .adapter import AdapterAGUI
44
+ from adapter_ag_ui .deps import StateDeps
43
45
44
46
has_ag_ui = True
45
47
69
71
UUID_PATTERN : Final [re .Pattern [str ]] = re .compile (r'\d{8}-\d{4}-\d{4}-\d{4}-\d{12}' )
70
72
71
73
74
+ class StateInt (BaseModel ):
75
+ """Example state class for testing purposes."""
76
+
77
+ value : int = 0
78
+
79
+
72
80
def get_weather () -> Tool :
73
81
return Tool (
74
82
name = 'get_weather' ,
@@ -87,7 +95,7 @@ def get_weather() -> Tool:
87
95
88
96
89
97
@pytest .fixture
90
- async def adapter () -> AdapterAGUI [None , str ]:
98
+ async def adapter () -> AdapterAGUI [StateDeps [ StateInt ] , str ]:
91
99
"""Fixture to create an AdapterAGUI instance for testing.
92
100
93
101
Returns:
@@ -96,7 +104,7 @@ async def adapter() -> AdapterAGUI[None, str]:
96
104
return await create_adapter ([])
97
105
98
106
99
- async def create_adapter (tools : list [str ] | Literal ['all' ] = 'all' ) -> AdapterAGUI [None , str ]:
107
+ async def create_adapter (tools : list [str ] | Literal ['all' ] = 'all' ) -> AdapterAGUI [StateDeps [ StateInt ] , str ]:
100
108
"""Create an AdapterAGUI instance for testing.
101
109
102
110
Args:
@@ -107,6 +115,7 @@ async def create_adapter(tools: list[str] | Literal['all'] = 'all') -> AdapterAG
107
115
"""
108
116
return Agent (
109
117
model = TestModel (tools ),
118
+ deps_type = cast (type [StateDeps [StateInt ]], StateDeps [StateInt ]),
110
119
tools = [send_snapshot , send_custom ],
111
120
).to_ag_ui ()
112
121
@@ -257,6 +266,7 @@ class AdapterRunTest:
257
266
runs : list [Run ]
258
267
call_tools : list [str ] = field (default_factory = lambda : list [str ]())
259
268
expected_events : list [str ] = field (default_factory = lambda : list (EXPECTED_EVENTS ))
269
+ expected_state : int | None = None
260
270
261
271
262
272
# Test parameter data
@@ -471,6 +481,22 @@ def tc_parameters() -> list[AdapterRunTest]:
471
481
'{"type":"RUN_FINISHED","threadId":"thread_00000000-0000-0000-0000-000000000001","runId":"run_00000000-0000-0000-0000-000000000002"}' ,
472
482
],
473
483
),
484
+ AdapterRunTest (
485
+ id = 'request_with_state' ,
486
+ runs = [
487
+ Run (
488
+ messages = [ # pyright: ignore[reportArgumentType]
489
+ UserMessage (
490
+ id = 'msg_1' ,
491
+ role = Role .USER .value ,
492
+ content = 'Hello, how are you?' ,
493
+ ),
494
+ ],
495
+ state = {'value' : 42 },
496
+ ),
497
+ ],
498
+ expected_state = 42 ,
499
+ ),
474
500
]
475
501
476
502
@@ -486,16 +512,19 @@ async def test_run_method(mock_uuid: _MockUUID, tc: AdapterRunTest) -> None:
486
512
run : Run
487
513
events : list [str ] = []
488
514
thread_id : str = f'{ THREAD_ID_PREFIX } { mock_uuid ()} '
489
- adapter : AdapterAGUI [None , str ] = await create_adapter (tc .call_tools )
515
+ adapter : AdapterAGUI [StateDeps [StateInt ], str ] = await create_adapter (tc .call_tools )
516
+ deps : StateDeps [StateInt ] = cast (StateDeps [StateInt ], StateDeps [StateInt ](state_type = StateInt ))
490
517
for run in tc .runs :
491
518
run_input : RunAgentInput = run .run_input (
492
519
thread_id = thread_id ,
493
520
run_id = f'{ RUN_ID_PREFIX } { mock_uuid ()} ' ,
494
521
)
495
522
496
- events .extend ([event async for event in adapter .run (run_input )])
523
+ events .extend ([event async for event in adapter .run (run_input , deps = deps )])
497
524
498
525
assert_events (events , tc .expected_events )
526
+ if tc .expected_state is not None :
527
+ assert deps .state .value == tc .expected_state
499
528
500
529
501
530
async def test_concurrent_runs (mock_uuid : _MockUUID , adapter : AdapterAGUI [None , str ]) -> None :
0 commit comments