Skip to content

Commit 20bd1d5

Browse files
committed
fix(ag-ui): partially unknown types
Fix partially known types lint violations in the ag-ui tests due to to conditional imports.
1 parent 38ea590 commit 20bd1d5

File tree

2 files changed

+13
-12
lines changed

2 files changed

+13
-12
lines changed

pydantic_ai_slim/pydantic_ai/models/test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@
66
from contextlib import asynccontextmanager
77
from dataclasses import InitVar, dataclass, field
88
from datetime import date, datetime, timedelta
9-
from typing import Any, Literal, TypeAlias, Union
9+
from typing import Any, Literal, Union
1010

1111
import pydantic_core
12-
from typing_extensions import assert_never
12+
from typing_extensions import TypeAlias, assert_never
1313

1414
from .. import _utils
1515
from ..messages import (

tests/test_ag_ui.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,6 @@
1818
from pydantic import BaseModel
1919

2020
from pydantic_ai import Agent
21-
from pydantic_ai.ag_ui import (
22-
SSE_CONTENT_TYPE,
23-
Adapter,
24-
Role,
25-
StateDeps,
26-
)
2721
from pydantic_ai.models.test import TestModel, TestNode, TestToolCallPart
2822

2923
has_ag_ui: bool = False
@@ -44,6 +38,13 @@
4438
UserMessage,
4539
)
4640

41+
from pydantic_ai.ag_ui import (
42+
SSE_CONTENT_TYPE,
43+
Adapter,
44+
Role,
45+
StateDeps,
46+
)
47+
4748
has_ag_ui = True
4849

4950

@@ -122,7 +123,7 @@ async def create_adapter(
122123
call_tools=call_tools,
123124
tool_call_deltas={'get_weather_parts', 'current_time'},
124125
),
125-
deps_type=StateDeps[StateInt],
126+
deps_type=StateDeps[StateInt], # type: ignore[reportUnknownArgumentType]
126127
tools=[send_snapshot, send_custom, current_time],
127128
)
128129
)
@@ -763,7 +764,7 @@ async def test_run_method(mock_uuid: _MockUUID, tc: AdapterRunTest) -> None:
763764
events: list[str] = []
764765
thread_id: str = f'{THREAD_ID_PREFIX}{mock_uuid()}'
765766
adapter: Adapter[StateDeps[StateInt], str] = await create_adapter(tc.call_tools)
766-
deps: StateDeps[StateInt] = StateDeps[StateInt](state_type=StateInt)
767+
deps: StateDeps[StateInt] = StateDeps[StateInt](state_type=StateInt) # type: ignore[reportUnknownArgumentType]
767768
for run in tc.runs:
768769
if run.nodes is not None:
769770
assert isinstance(adapter.agent.model, TestModel), 'Agent model is not TestModel'
@@ -774,11 +775,11 @@ async def test_run_method(mock_uuid: _MockUUID, tc: AdapterRunTest) -> None:
774775
run_id=f'{RUN_ID_PREFIX}{mock_uuid()}',
775776
)
776777

777-
events.extend([event async for event in adapter.run(run_input, deps=deps)])
778+
events.extend([event async for event in adapter.run(run_input, deps=deps)]) # type: ignore[reportUnknownArgumentType]
778779

779780
assert_events(events, tc.expected_events)
780781
if tc.expected_state is not None:
781-
assert deps.state.value == tc.expected_state
782+
assert deps.state.value == tc.expected_state # type: ignore[reportUnknownArgumentType]
782783

783784

784785
async def test_concurrent_runs(mock_uuid: _MockUUID, adapter: Adapter[None, str]) -> None:

0 commit comments

Comments
 (0)