diff --git a/python/monarch/_src/actor/actor_mesh.py b/python/monarch/_src/actor/actor_mesh.py index 2c29083f..e16261be 100644 --- a/python/monarch/_src/actor/actor_mesh.py +++ b/python/monarch/_src/actor/actor_mesh.py @@ -34,6 +34,7 @@ Literal, NamedTuple, Optional, + overload, ParamSpec, Sequence, Tuple, @@ -457,7 +458,13 @@ def send( class EndpointProperty(Generic[P, R]): - def __init__(self, method: Callable[Concatenate[Any, P], Awaitable[R]]) -> None: + @overload + def __init__(self, method: Callable[Concatenate[Any, P], Awaitable[R]]) -> None: ... + + @overload + def __init__(self, method: Callable[Concatenate[Any, P], R]) -> None: ... + + def __init__(self, method: Any) -> None: self._method = method def __get__(self, instance, owner) -> Endpoint[P, R]: @@ -467,9 +474,19 @@ def __get__(self, instance, owner) -> Endpoint[P, R]: return cast(Endpoint[P, R], self) +@overload def endpoint( method: Callable[Concatenate[Any, P], Awaitable[R]], -) -> EndpointProperty[P, R]: +) -> EndpointProperty[P, R]: ... + + +@overload +def endpoint( + method: Callable[Concatenate[Any, P], R], +) -> EndpointProperty[P, R]: ... + + +def endpoint(method): return EndpointProperty(method)