Skip to content

Commit 352acff

Browse files
authored
Proper check if callable is async (#1972)
1 parent 67c381e commit 352acff

File tree

6 files changed

+58
-11
lines changed

6 files changed

+58
-11
lines changed

pydantic_ai_slim/pydantic_ai/_function_schema.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
from __future__ import annotations as _annotations
77

8-
import inspect
98
from collections.abc import Awaitable
109
from dataclasses import dataclass, field
1110
from inspect import Parameter, signature
@@ -23,7 +22,7 @@
2322
from pydantic_ai.tools import RunContext
2423

2524
from ._griffe import doc_descriptions
26-
from ._utils import check_object_json_schema, is_model_like, run_in_executor
25+
from ._utils import check_object_json_schema, is_async_callable, is_model_like, run_in_executor
2726

2827
if TYPE_CHECKING:
2928
from .tools import DocstringFormat, ObjectJsonSchema
@@ -214,7 +213,7 @@ def function_schema( # noqa: C901
214213
positional_fields=positional_fields,
215214
var_positional_field=var_positional_field,
216215
takes_ctx=takes_ctx,
217-
is_async=inspect.iscoroutinefunction(function),
216+
is_async=is_async_callable(function),
218217
function=function,
219218
)
220219

pydantic_ai_slim/pydantic_ai/_output.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ class OutputValidator(Generic[AgentDepsT, OutputDataT_inv]):
6060

6161
def __post_init__(self):
6262
self._takes_ctx = len(inspect.signature(self.function).parameters) > 1
63-
self._is_async = inspect.iscoroutinefunction(self.function)
63+
self._is_async = _utils.is_async_callable(self.function)
6464

6565
async def validate(
6666
self,

pydantic_ai_slim/pydantic_ai/_system_prompt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class SystemPromptRunner(Generic[AgentDepsT]):
1818

1919
def __post_init__(self):
2020
self._takes_ctx = len(inspect.signature(self.function).parameters) > 0
21-
self._is_async = inspect.iscoroutinefunction(self.function)
21+
self._is_async = _utils.is_async_callable(self.function)
2222

2323
async def run(self, run_context: RunContext[AgentDepsT]) -> str:
2424
if self._takes_ctx:

pydantic_ai_slim/pydantic_ai/_utils.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,22 @@
11
from __future__ import annotations as _annotations
22

33
import asyncio
4+
import functools
5+
import inspect
46
import time
57
import uuid
6-
from collections.abc import AsyncIterable, AsyncIterator, Iterator
8+
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Iterator
79
from contextlib import asynccontextmanager, suppress
810
from dataclasses import dataclass, fields, is_dataclass
911
from datetime import datetime, timezone
1012
from functools import partial
1113
from types import GenericAlias
12-
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union
14+
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, overload
1315

1416
from anyio.to_thread import run_sync
1517
from pydantic import BaseModel, TypeAdapter
1618
from pydantic.json_schema import JsonSchemaValue
17-
from typing_extensions import ParamSpec, TypeAlias, TypeGuard, is_typeddict
19+
from typing_extensions import ParamSpec, TypeAlias, TypeGuard, TypeIs, is_typeddict
1820

1921
from pydantic_graph._utils import AbstractSpan
2022

@@ -302,3 +304,26 @@ def dataclasses_no_defaults_repr(self: Any) -> str:
302304

303305
def number_to_datetime(x: int | float) -> datetime:
304306
return TypeAdapter(datetime).validate_python(x)
307+
308+
309+
AwaitableCallable = Callable[..., Awaitable[T]]
310+
311+
312+
@overload
313+
def is_async_callable(obj: AwaitableCallable[T]) -> TypeIs[AwaitableCallable[T]]: ...
314+
315+
316+
@overload
317+
def is_async_callable(obj: Any) -> TypeIs[AwaitableCallable[Any]]: ...
318+
319+
320+
def is_async_callable(obj: Any) -> Any:
321+
"""Correctly check if a callable is async.
322+
323+
This function was copied from Starlette:
324+
https://github.com/encode/starlette/blob/78da9b9e218ab289117df7d62aee200ed4c59617/starlette/_utils.py#L36-L40
325+
"""
326+
while isinstance(obj, functools.partial):
327+
obj = obj.func
328+
329+
return inspect.iscoroutinefunction(obj) or (callable(obj) and inspect.iscoroutinefunction(obj.__call__)) # type: ignore

pydantic_ai_slim/pydantic_ai/tools.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations as _annotations
22

3-
import asyncio
43
import dataclasses
54
import json
65
from collections.abc import Awaitable, Sequence
@@ -337,7 +336,7 @@ def from_schema(
337336
validator=SchemaValidator(schema=core_schema.any_schema()),
338337
json_schema=json_schema,
339338
takes_ctx=False,
340-
is_async=asyncio.iscoroutinefunction(function),
339+
is_async=_utils.is_async_callable(function),
341340
)
342341

343342
return cls(

tests/test_utils.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import asyncio
44
import contextvars
5+
import functools
56
import os
67
from collections.abc import AsyncIterator
78
from importlib.metadata import distributions
@@ -10,7 +11,14 @@
1011
from inline_snapshot import snapshot
1112

1213
from pydantic_ai import UserError
13-
from pydantic_ai._utils import UNSET, PeekableAsyncStream, check_object_json_schema, group_by_temporal, run_in_executor
14+
from pydantic_ai._utils import (
15+
UNSET,
16+
PeekableAsyncStream,
17+
check_object_json_schema,
18+
group_by_temporal,
19+
is_async_callable,
20+
run_in_executor,
21+
)
1422

1523
from .models.mock_async_stream import MockAsyncStream
1624

@@ -153,3 +161,19 @@ async def test_run_in_executor_with_contextvars() -> None:
153161
# show that the old version did not work
154162
old_result = asyncio.get_running_loop().run_in_executor(None, ctx_var.get)
155163
assert old_result != ctx_var.get()
164+
165+
166+
def test_is_async_callable():
167+
def sync_func(): ... # pragma: no branch
168+
169+
assert is_async_callable(sync_func) is False
170+
171+
async def async_func(): ... # pragma: no branch
172+
173+
assert is_async_callable(async_func) is True
174+
175+
class AsyncCallable:
176+
async def __call__(self): ... # pragma: no branch
177+
178+
partial_async_callable = functools.partial(AsyncCallable())
179+
assert is_async_callable(partial_async_callable) is True

0 commit comments

Comments
 (0)