Skip to content

Commit 33fb34f

Browse files
Improve grpcio serializer/deserializer types (#14093)
1 parent 798f332 commit 33fb34f

File tree

2 files changed

+48
-74
lines changed

2 files changed

+48
-74
lines changed

stubs/grpcio/grpc/__init__.pyi

Lines changed: 30 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,11 @@ from . import aio as aio
1111

1212
__version__: str
1313

14-
# This class encodes an uninhabited type, requiring use of explicit casts or ignores
15-
# in order to satisfy type checkers. This allows grpc-stubs to add proper stubs
16-
# later, allowing those overrides to be removed.
17-
# The alternative is Any, but a future replacement of Any with a proper type
18-
# would result in type errors where previously the type checker was happy, which
19-
# we want to avoid. Forcing the user to use overrides provides forwards-compatibility.
20-
@type_check_only
21-
class _PartialStubMustCastOrIgnore: ...
14+
_T = TypeVar("_T")
2215

2316
# XXX: Early attempts to tame this used literals for all the keys (gRPC is
2417
# a bit segfaulty and doesn't adequately validate the option keys), but that
25-
# didn't quite work out. Maybe it's something we can come back to?
18+
# didn't quite work out. Maybe it's something we can come back to
2619
_OptionKeyValue: TypeAlias = tuple[str, Any]
2720
_Options: TypeAlias = Sequence[_OptionKeyValue]
2821

@@ -45,24 +38,8 @@ _Metadata: TypeAlias = tuple[tuple[str, str | bytes], ...]
4538

4639
_TRequest = TypeVar("_TRequest")
4740
_TResponse = TypeVar("_TResponse")
48-
49-
# XXX: These are probably the SerializeToTring/FromString pb2 methods, but
50-
# this needs further investigation
51-
@type_check_only
52-
class _RequestSerializer(Protocol):
53-
def __call__(self, *args: Any, **kwargs: Any) -> Any: ...
54-
55-
@type_check_only
56-
class _RequestDeserializer(Protocol):
57-
def __call__(self, *args: Any, **kwargs: Any) -> Any: ...
58-
59-
@type_check_only
60-
class _ResponseSerializer(Protocol):
61-
def __call__(self, *args: Any, **kwargs: Any) -> Any: ...
62-
63-
@type_check_only
64-
class _ResponseDeserializer(Protocol):
65-
def __call__(self, *args: Any, **kwargs: Any) -> Any: ...
41+
_Serializer: TypeAlias = Callable[[_T], bytes]
42+
_Deserializer: TypeAlias = Callable[[bytes], _T]
6643

6744
# Future Interfaces:
6845

@@ -176,24 +153,24 @@ class _Behaviour(Protocol):
176153

177154
def unary_unary_rpc_method_handler(
178155
behavior: _Behaviour,
179-
request_deserializer: _RequestDeserializer | None = None,
180-
response_serializer: _ResponseSerializer | None = None,
181-
) -> RpcMethodHandler[Any, Any]: ...
156+
request_deserializer: _Deserializer[_TRequest] | None = None,
157+
response_serializer: _Serializer[_TResponse] | None = None,
158+
) -> RpcMethodHandler[_TRequest, _TResponse]: ...
182159
def unary_stream_rpc_method_handler(
183160
behavior: _Behaviour,
184-
request_deserializer: _RequestDeserializer | None = None,
185-
response_serializer: _ResponseSerializer | None = None,
186-
) -> RpcMethodHandler[Any, Any]: ...
161+
request_deserializer: _Deserializer[_TRequest] | None = None,
162+
response_serializer: _Serializer[_TResponse] | None = None,
163+
) -> RpcMethodHandler[_TRequest, _TResponse]: ...
187164
def stream_unary_rpc_method_handler(
188165
behavior: _Behaviour,
189-
request_deserializer: _RequestDeserializer | None = None,
190-
response_serializer: _ResponseSerializer | None = None,
191-
) -> RpcMethodHandler[Any, Any]: ...
166+
request_deserializer: _Deserializer[_TRequest] | None = None,
167+
response_serializer: _Serializer[_TResponse] | None = None,
168+
) -> RpcMethodHandler[_TRequest, _TResponse]: ...
192169
def stream_stream_rpc_method_handler(
193170
behavior: _Behaviour,
194-
request_deserializer: _RequestDeserializer | None = None,
195-
response_serializer: _ResponseSerializer | None = None,
196-
) -> RpcMethodHandler[Any, Any]: ...
171+
request_deserializer: _Deserializer[_TRequest] | None = None,
172+
response_serializer: _Serializer[_TResponse] | None = None,
173+
) -> RpcMethodHandler[_TRequest, _TResponse]: ...
197174
def method_handlers_generic_handler(
198175
service: str, method_handlers: dict[str, RpcMethodHandler[Any, Any]]
199176
) -> GenericRpcHandler[Any, Any]: ...
@@ -250,32 +227,32 @@ class Channel(abc.ABC):
250227
def stream_stream(
251228
self,
252229
method: str,
253-
request_serializer: _RequestSerializer | None = None,
254-
response_deserializer: _ResponseDeserializer | None = None,
255-
) -> StreamStreamMultiCallable[Any, Any]: ...
230+
request_serializer: _Serializer[_TRequest] | None = None,
231+
response_deserializer: _Deserializer[_TResponse] | None = None,
232+
) -> StreamStreamMultiCallable[_TRequest, _TResponse]: ...
256233
@abc.abstractmethod
257234
def stream_unary(
258235
self,
259236
method: str,
260-
request_serializer: _RequestSerializer | None = None,
261-
response_deserializer: _ResponseDeserializer | None = None,
262-
) -> StreamUnaryMultiCallable[Any, Any]: ...
237+
request_serializer: _Serializer[_TRequest] | None = None,
238+
response_deserializer: _Deserializer[_TResponse] | None = None,
239+
) -> StreamUnaryMultiCallable[_TRequest, _TResponse]: ...
263240
@abc.abstractmethod
264241
def subscribe(self, callback: Callable[[ChannelConnectivity], None], try_to_connect: bool = False) -> None: ...
265242
@abc.abstractmethod
266243
def unary_stream(
267244
self,
268245
method: str,
269-
request_serializer: _RequestSerializer | None = None,
270-
response_deserializer: _ResponseDeserializer | None = None,
271-
) -> UnaryStreamMultiCallable[Any, Any]: ...
246+
request_serializer: _Serializer[_TRequest] | None = None,
247+
response_deserializer: _Deserializer[_TResponse] | None = None,
248+
) -> UnaryStreamMultiCallable[_TRequest, _TResponse]: ...
272249
@abc.abstractmethod
273250
def unary_unary(
274251
self,
275252
method: str,
276-
request_serializer: _RequestSerializer | None = None,
277-
response_deserializer: _ResponseDeserializer | None = None,
278-
) -> UnaryUnaryMultiCallable[Any, Any]: ...
253+
request_serializer: _Serializer[_TRequest] | None = None,
254+
response_deserializer: _Deserializer[_TResponse] | None = None,
255+
) -> UnaryUnaryMultiCallable[_TRequest, _TResponse]: ...
279256
@abc.abstractmethod
280257
def unsubscribe(self, callback: Callable[[ChannelConnectivity], None]) -> None: ...
281258
def __enter__(self) -> Self: ...
@@ -499,10 +476,10 @@ class RpcMethodHandler(abc.ABC, Generic[_TRequest, _TResponse]):
499476
response_streaming: bool
500477

501478
# XXX: not clear from docs whether this is optional or not
502-
request_deserializer: _RequestDeserializer | None
479+
request_deserializer: _Deserializer[_TRequest] | None
503480

504481
# XXX: not clear from docs whether this is optional or not
505-
response_serializer: _ResponseSerializer | None
482+
response_serializer: _Serializer[_TResponse] | None
506483

507484
unary_unary: Callable[[_TRequest, ServicerContext], _TResponse] | None
508485

stubs/grpcio/grpc/aio/__init__.pyi

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -74,11 +74,8 @@ def server(
7474

7575
# Channel Object:
7676

77-
# XXX: The docs suggest these type signatures for aio, but not for non-async,
78-
# and it's unclear why;
79-
# https://grpc.github.io/grpc/python/grpc_asyncio.html#grpc.aio.Channel.stream_stream
80-
_RequestSerializer: TypeAlias = Callable[[Any], bytes]
81-
_ResponseDeserializer: TypeAlias = Callable[[bytes], Any]
77+
_Serializer: TypeAlias = Callable[[_T], bytes]
78+
_Deserializer: TypeAlias = Callable[[bytes], _T]
8279

8380
class Channel(abc.ABC):
8481
@abc.abstractmethod
@@ -91,30 +88,30 @@ class Channel(abc.ABC):
9188
def stream_stream(
9289
self,
9390
method: str,
94-
request_serializer: _RequestSerializer | None = None,
95-
response_deserializer: _ResponseDeserializer | None = None,
96-
) -> StreamStreamMultiCallable[Any, Any]: ...
91+
request_serializer: _Serializer[_TRequest] | None = None,
92+
response_deserializer: _Deserializer[_TResponse] | None = None,
93+
) -> StreamStreamMultiCallable[_TRequest, _TResponse]: ...
9794
@abc.abstractmethod
9895
def stream_unary(
9996
self,
10097
method: str,
101-
request_serializer: _RequestSerializer | None = None,
102-
response_deserializer: _ResponseDeserializer | None = None,
103-
) -> StreamUnaryMultiCallable[Any, Any]: ...
98+
request_serializer: _Serializer[_TRequest] | None = None,
99+
response_deserializer: _Deserializer[_TResponse] | None = None,
100+
) -> StreamUnaryMultiCallable[_TRequest, _TResponse]: ...
104101
@abc.abstractmethod
105102
def unary_stream(
106103
self,
107104
method: str,
108-
request_serializer: _RequestSerializer | None = None,
109-
response_deserializer: _ResponseDeserializer | None = None,
110-
) -> UnaryStreamMultiCallable[Any, Any]: ...
105+
request_serializer: _Serializer[_TRequest] | None = None,
106+
response_deserializer: _Deserializer[_TResponse] | None = None,
107+
) -> UnaryStreamMultiCallable[_TRequest, _TResponse]: ...
111108
@abc.abstractmethod
112109
def unary_unary(
113110
self,
114111
method: str,
115-
request_serializer: _RequestSerializer | None = None,
116-
response_deserializer: _ResponseDeserializer | None = None,
117-
) -> UnaryUnaryMultiCallable[Any, Any]: ...
112+
request_serializer: _Serializer[_TRequest] | None = None,
113+
response_deserializer: _Deserializer[_TResponse] | None = None,
114+
) -> UnaryUnaryMultiCallable[_TRequest, _TResponse]: ...
118115
@abc.abstractmethod
119116
async def __aenter__(self) -> Self: ...
120117
@abc.abstractmethod
@@ -299,8 +296,8 @@ class InterceptedUnaryUnaryCall(_InterceptedCall[_TRequest, _TResponse], metacla
299296
wait_for_ready: bool | None,
300297
channel: Channel,
301298
method: bytes,
302-
request_serializer: _RequestSerializer,
303-
response_deserializer: _ResponseDeserializer,
299+
request_serializer: _Serializer[_TRequest],
300+
response_deserializer: _Deserializer[_TResponse],
304301
loop: asyncio.AbstractEventLoop,
305302
) -> None: ...
306303

@@ -314,8 +311,8 @@ class InterceptedUnaryUnaryCall(_InterceptedCall[_TRequest, _TResponse], metacla
314311
credentials: CallCredentials | None,
315312
wait_for_ready: bool | None,
316313
request: _TRequest,
317-
request_serializer: _RequestSerializer,
318-
response_deserializer: _ResponseDeserializer,
314+
request_serializer: _Serializer[_TRequest],
315+
response_deserializer: _Deserializer[_TResponse],
319316
) -> UnaryUnaryCall[_TRequest, _TResponse]: ...
320317
def time_remaining(self) -> float | None: ...
321318

0 commit comments

Comments
 (0)