Skip to content

Commit 25746b3

Browse files
committed
chore(ag-ui): improve test coverage
1 parent 1dee214 commit 25746b3

File tree

4 files changed

+45
-13
lines changed

4 files changed

+45
-13
lines changed

adapter_ag_ui/adapter_ag_ui/_exceptions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def __post_init__(self, tool_name: str) -> None:
3232
Args:
3333
tool_name: The name of the tool that was unexpectedly called.
3434
"""
35-
self.message = f'unexpected tool call name={tool_name}'
35+
self.message = f'unexpected tool call name={tool_name}' # pragma: no cover
3636

3737

3838
@dataclass

adapter_ag_ui/adapter_ag_ui/adapter.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,7 @@ async def _tool_events(
316316
if isinstance(item, BaseEvent):
317317
self.logger.debug('ag-ui event: %s', item)
318318
yield encoder.encode(item)
319-
case _:
319+
case _: # pragma: no cover
320320
# Not currently interested in other types.
321321
pass
322322

@@ -354,7 +354,7 @@ def _tool_stub(*args: Any, **kwargs: Any) -> ToolResult:
354354
Raises:
355355
UnexpectedToolCallError: Always raised since this is a stub.
356356
"""
357-
raise UnexpectedToolCallError(tool_name=tool.name)
357+
raise UnexpectedToolCallError(tool_name=tool.name) # pragma: no cover
358358

359359
# TODO(steve): See it we can avoid the cast here.
360360
return cast(
@@ -486,7 +486,7 @@ async def _handle_agent_event(
486486
# TODO(steve): Merge with above list if we don't remove the local_tool_calls case.
487487
if tool_name:
488488
stream_ctx.part_ends.append(None) # Signal continuation of the stream.
489-
case ThinkingPart():
489+
case ThinkingPart(): # pragma: no cover
490490
# No equivalent AG-UI event yet.
491491
pass
492492
case PartDeltaEvent():
@@ -515,7 +515,7 @@ async def _handle_agent_event(
515515
else json.dumps(agent_event.delta.args_delta),
516516
),
517517
)
518-
case ThinkingPartDelta():
518+
case ThinkingPartDelta(): # pragma: no cover
519519
# No equivalent AG-UI event yet.
520520
pass
521521
case FinalResultEvent():

adapter_ag_ui/adapter_ag_ui/deps.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,15 @@ def set_state(self, state: State) -> None:
2727
Implements the `StateHandler` protocol.
2828
2929
Args:
30-
state: The run state.
30+
state: The run state, which should match the expected model type or be `None`.
3131
3232
Raises:
33-
InvalidStateError: If `state` does not match the expected model.
33+
InvalidStateError: If `state` does not match the expected model and is not `None`.
3434
"""
35+
if state is None:
36+
return
37+
3538
try:
3639
self.state = self.state_type.model_validate(state)
37-
except ValidationError as e:
40+
except ValidationError as e: # pragma: no cover
3841
raise InvalidStateError from e

tests/adapter_ag_ui/test_adapter.py

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,10 @@
1111
from collections.abc import Callable
1212
from dataclasses import dataclass, field
1313
from itertools import count
14-
from typing import Any, Final, Literal
14+
from typing import Any, Final, Literal, cast
1515

1616
import pytest
17+
from pydantic import BaseModel
1718

1819
from pydantic_ai import Agent
1920
from pydantic_ai.models.test import TestModel
@@ -40,6 +41,7 @@
4041

4142
from adapter_ag_ui._enums import Role
4243
from adapter_ag_ui.adapter import AdapterAGUI
44+
from adapter_ag_ui.deps import StateDeps
4345

4446
has_ag_ui = True
4547

@@ -69,6 +71,12 @@
6971
UUID_PATTERN: Final[re.Pattern[str]] = re.compile(r'\d{8}-\d{4}-\d{4}-\d{4}-\d{12}')
7072

7173

74+
class StateInt(BaseModel):
75+
"""Example state class for testing purposes."""
76+
77+
value: int = 0
78+
79+
7280
def get_weather() -> Tool:
7381
return Tool(
7482
name='get_weather',
@@ -87,7 +95,7 @@ def get_weather() -> Tool:
8795

8896

8997
@pytest.fixture
90-
async def adapter() -> AdapterAGUI[None, str]:
98+
async def adapter() -> AdapterAGUI[StateDeps[StateInt], str]:
9199
"""Fixture to create an AdapterAGUI instance for testing.
92100
93101
Returns:
@@ -96,7 +104,7 @@ async def adapter() -> AdapterAGUI[None, str]:
96104
return await create_adapter([])
97105

98106

99-
async def create_adapter(tools: list[str] | Literal['all'] = 'all') -> AdapterAGUI[None, str]:
107+
async def create_adapter(tools: list[str] | Literal['all'] = 'all') -> AdapterAGUI[StateDeps[StateInt], str]:
100108
"""Create an AdapterAGUI instance for testing.
101109
102110
Args:
@@ -107,6 +115,7 @@ async def create_adapter(tools: list[str] | Literal['all'] = 'all') -> AdapterAG
107115
"""
108116
return Agent(
109117
model=TestModel(tools),
118+
deps_type=cast(type[StateDeps[StateInt]], StateDeps[StateInt]),
110119
tools=[send_snapshot, send_custom],
111120
).to_ag_ui()
112121

@@ -257,6 +266,7 @@ class AdapterRunTest:
257266
runs: list[Run]
258267
call_tools: list[str] = field(default_factory=lambda: list[str]())
259268
expected_events: list[str] = field(default_factory=lambda: list(EXPECTED_EVENTS))
269+
expected_state: int | None = None
260270

261271

262272
# Test parameter data
@@ -471,6 +481,22 @@ def tc_parameters() -> list[AdapterRunTest]:
471481
'{"type":"RUN_FINISHED","threadId":"thread_00000000-0000-0000-0000-000000000001","runId":"run_00000000-0000-0000-0000-000000000002"}',
472482
],
473483
),
484+
AdapterRunTest(
485+
id='request_with_state',
486+
runs=[
487+
Run(
488+
messages=[ # pyright: ignore[reportArgumentType]
489+
UserMessage(
490+
id='msg_1',
491+
role=Role.USER.value,
492+
content='Hello, how are you?',
493+
),
494+
],
495+
state={'value': 42},
496+
),
497+
],
498+
expected_state=42,
499+
),
474500
]
475501

476502

@@ -486,16 +512,19 @@ async def test_run_method(mock_uuid: _MockUUID, tc: AdapterRunTest) -> None:
486512
run: Run
487513
events: list[str] = []
488514
thread_id: str = f'{THREAD_ID_PREFIX}{mock_uuid()}'
489-
adapter: AdapterAGUI[None, str] = await create_adapter(tc.call_tools)
515+
adapter: AdapterAGUI[StateDeps[StateInt], str] = await create_adapter(tc.call_tools)
516+
deps: StateDeps[StateInt] = cast(StateDeps[StateInt], StateDeps[StateInt](state_type=StateInt))
490517
for run in tc.runs:
491518
run_input: RunAgentInput = run.run_input(
492519
thread_id=thread_id,
493520
run_id=f'{RUN_ID_PREFIX}{mock_uuid()}',
494521
)
495522

496-
events.extend([event async for event in adapter.run(run_input)])
523+
events.extend([event async for event in adapter.run(run_input, deps=deps)])
497524

498525
assert_events(events, tc.expected_events)
526+
if tc.expected_state is not None:
527+
assert deps.state.value == tc.expected_state
499528

500529

501530
async def test_concurrent_runs(mock_uuid: _MockUUID, adapter: AdapterAGUI[None, str]) -> None:

0 commit comments

Comments
 (0)