Skip to content

Commit 32ea655

Browse files
committed
Fix non-async endpoint type checking
It wasn't working correctly if the function didn't return an Awaitable. Differential Revision: [D78181976](https://our.internmc.facebook.com/intern/diff/D78181976/) **NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D78181976/)! ghstack-source-id: 295713286 Pull Request resolved: #510
1 parent d54345a commit 32ea655

File tree

1 file changed

+19
-2
lines changed

1 file changed

+19
-2
lines changed

python/monarch/_src/actor/actor_mesh.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
Literal,
3535
NamedTuple,
3636
Optional,
37+
overload,
3738
ParamSpec,
3839
Sequence,
3940
Tuple,
@@ -457,7 +458,13 @@ def send(
457458

458459

459460
class EndpointProperty(Generic[P, R]):
460-
def __init__(self, method: Callable[Concatenate[Any, P], Awaitable[R]]) -> None:
461+
@overload
462+
def __init__(self, method: Callable[Concatenate[Any, P], Awaitable[R]]) -> None: ...
463+
464+
@overload
465+
def __init__(self, method: Callable[Concatenate[Any, P], R]) -> None: ...
466+
467+
def __init__(self, method: Any) -> None:
461468
self._method = method
462469

463470
def __get__(self, instance, owner) -> Endpoint[P, R]:
@@ -467,9 +474,19 @@ def __get__(self, instance, owner) -> Endpoint[P, R]:
467474
return cast(Endpoint[P, R], self)
468475

469476

477+
@overload
470478
def endpoint(
471479
method: Callable[Concatenate[Any, P], Awaitable[R]],
472-
) -> EndpointProperty[P, R]:
480+
) -> EndpointProperty[P, R]: ...
481+
482+
483+
@overload
484+
def endpoint(
485+
method: Callable[Concatenate[Any, P], R],
486+
) -> EndpointProperty[P, R]: ...
487+
488+
489+
def endpoint(method):
473490
return EndpointProperty(method)
474491

475492

0 commit comments

Comments
 (0)