Skip to content

Commit 5c3805d

Browse files
committed
improved type hinting
1 parent 5baa2aa commit 5c3805d

File tree

2 files changed

+87
-43
lines changed

2 files changed

+87
-43
lines changed

google/api_core/retry_streaming.py

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,29 +14,29 @@
1414

1515
"""Helpers for retries for streaming APIs."""
1616

17-
from typing import Callable, Optional, Iterable
17+
from typing import Callable, Optional, Iterable, Iterator, Generator, TypeVar, Any, cast
1818

1919
import datetime
2020
import logging
2121
import time
2222

23-
from collections.abc import Generator
24-
2523
from google.api_core import datetime_helpers
2624
from google.api_core import exceptions
2725

2826
_LOGGER = logging.getLogger(__name__)
2927

28+
T = TypeVar("T")
29+
3030

31-
class RetryableGenerator(Generator):
31+
class RetryableGenerator(Generator[T, Any, None]):
3232
"""
3333
Helper class for retrying Iterator and Generator-based
3434
streaming APIs.
3535
"""
3636

3737
def __init__(
3838
self,
39-
target: Callable[[], Iterable],
39+
target: Callable[[], Iterable[T]],
4040
predicate: Callable[[Exception], bool],
4141
sleep_generator: Iterable[float],
4242
timeout: Optional[float] = None,
@@ -52,13 +52,13 @@ def __init__(
5252
It should return True to retry or False otherwise.
5353
sleep_generator: An infinite iterator that determines
5454
how long to sleep between retries.
55-
timeout: How long to keep retrying the target.
55+
timeout: How long to keep retrying the target, in seconds.
5656
on_error: A function to call while processing a
5757
retryable exception. Any error raised by this function will *not*
5858
be caught.
5959
"""
6060
self.target_fn = target
61-
self.active_target: Iterable = self.target_fn()
61+
self.active_target: Iterator = self.target_fn().__iter__()
6262
self.predicate = predicate
6363
self.sleep_generator = iter(sleep_generator)
6464
self.on_error = on_error
@@ -70,13 +70,13 @@ def __init__(
7070
else:
7171
self.deadline = None
7272

73-
def __iter__(self):
73+
def __iter__(self) -> Generator[T, Any, None]:
7474
"""
7575
Implement the iterator protocol.
7676
"""
7777
return self
7878

79-
def _handle_exception(self, exc):
79+
def _handle_exception(self, exc) -> None:
8080
"""
8181
When an exception is raised while iterating over the active_target,
8282
check if it is retryable. If so, create a new active_target and
@@ -108,9 +108,12 @@ def _handle_exception(self, exc):
108108
time.sleep(next_sleep)
109109
self.active_target = self.target_fn()
110110

111-
def __next__(self):
111+
def __next__(self) -> T:
112112
"""
113113
Implement the iterator protocol.
114+
115+
Returns:
116+
- the next value of the active_target iterator
114117
"""
115118
try:
116119
return next(self.active_target)
@@ -119,7 +122,7 @@ def __next__(self):
119122
# if retryable exception was handled, try again with new active_target
120123
return self.__next__()
121124

122-
def close(self):
125+
def close(self) -> None:
123126
"""
124127
Close the active_target if supported. (e.g. target is a generator)
125128
@@ -133,22 +136,25 @@ def close(self):
133136
"close() not implemented for {}".format(self.active_target)
134137
)
135138

136-
def send(self, value):
139+
def send(self, *args, **kwargs) -> T:
137140
"""
138141
Call send on the active_target if supported. (e.g. target is a generator)
139142
140143
If an exception is raised, a retry may be attempted before returning
141144
a result.
142145
146+
Args:
147+
- *args: arguments to pass to the wrapped generator's send method
148+
- **kwargs: keyword arguments to pass to the wrapped generator's send method
143149
Returns:
144-
- the result of calling send() on the active_target
145-
150+
- the next value of the active_target iterator after calling send
146151
Raises:
147152
- AttributeError if the active_target does not have a send() method
148153
"""
149154
if getattr(self.active_target, "send", None):
155+
casted_target = cast(Generator, self.active_target)
150156
try:
151-
return self.active_target.send(value)
157+
return casted_target.send(*args, **kwargs)
152158
except Exception as exc:
153159
self._handle_exception(exc)
154160
# if exception was retryable, use new target for return value
@@ -158,21 +164,25 @@ def send(self, value):
158164
"send() not implemented for {}".format(self.active_target)
159165
)
160166

161-
def throw(self, typ, val=None, tb=None):
167+
def throw(self, *args, **kwargs) -> T:
162168
"""
163169
Call throw on the active_target if supported. (e.g. target is a generator)
164170
165171
If an exception is raised, a retry may be attempted before returning
166172
a result.
167173
174+
Args:
175+
- *args: arguments to pass to the wrapped generator's throw method
176+
- **kwargs: keyword arguments to pass to the wrapped generator's throw method
168177
Returns:
169-
- the result of calling throw() on the active_target
178+
- the next vale of the active_target iterator after calling throw
170179
Raises:
171180
- AttributeError if the active_target does not have a throw() method
172181
"""
173182
if getattr(self.active_target, "throw", None):
183+
casted_target = cast(Generator, self.active_target)
174184
try:
175-
return self.active_target.throw(typ, val, tb)
185+
return casted_target.throw(*args, **kwargs)
176186
except Exception as exc:
177187
self._handle_exception(exc)
178188
# if retryable exception was handled, return next from new active_target

google/api_core/retry_streaming_async.py

Lines changed: 59 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,35 @@
1414

1515
"""Helpers for retries for async streaming APIs."""
1616

17-
from typing import Callable, Optional, Iterable, AsyncIterable, Awaitable, Union
17+
from typing import (
18+
cast,
19+
Callable,
20+
Optional,
21+
Iterable,
22+
AsyncIterator,
23+
AsyncIterable,
24+
Awaitable,
25+
Union,
26+
Any,
27+
TypeVar,
28+
AsyncGenerator,
29+
)
1830

1931
import asyncio
2032
import inspect
2133
import logging
34+
import datetime
2235

23-
from collections.abc import AsyncGenerator
2436

2537
from google.api_core import datetime_helpers
2638
from google.api_core import exceptions
2739

2840
_LOGGER = logging.getLogger(__name__)
2941

42+
T = TypeVar("T")
3043

31-
class AsyncRetryableGenerator(AsyncGenerator):
44+
45+
class AsyncRetryableGenerator(AsyncGenerator[T, None]):
3246
"""
3347
Helper class for retrying AsyncIterator and AsyncGenerator-based
3448
streaming APIs.
@@ -37,7 +51,8 @@ class AsyncRetryableGenerator(AsyncGenerator):
3751
def __init__(
3852
self,
3953
target: Union[
40-
Callable[[], AsyncIterable], Callable[[], Awaitable[AsyncIterable]]
54+
Callable[[], AsyncIterable[T]],
55+
Callable[[], Awaitable[AsyncIterable[T]]],
4156
],
4257
predicate: Callable[[Exception], bool],
4358
sleep_generator: Iterable[float],
@@ -61,27 +76,32 @@ def __init__(
6176
"""
6277
self.target_fn = target
6378
# active target must be populated in an async context
64-
self.active_target: Optional[AsyncIterable] = None
79+
self.active_target: Optional[AsyncIterator[T]] = None
6580
self.predicate = predicate
6681
self.sleep_generator = iter(sleep_generator)
6782
self.on_error = on_error
6883
self.timeout = timeout
6984
self.remaining_timeout_budget = timeout if timeout else None
7085

71-
async def _ensure_active_target(self):
86+
async def _ensure_active_target(self) -> AsyncIterator[T]:
7287
"""
7388
Ensure that the active target is populated and ready to be iterated over.
89+
90+
Returns:
91+
- The active_target iterable
7492
"""
7593
if not self.active_target:
76-
self.active_target = self.target_fn()
77-
if inspect.iscoroutine(self.active_target):
78-
self.active_target = await self.active_target
94+
new_iterable = self.target_fn()
95+
if isinstance(new_iterable, Awaitable):
96+
new_iterable = await new_iterable
97+
self.active_target = new_iterable.__aiter__()
98+
return self.active_target
7999

80-
def __aiter__(self):
100+
def __aiter__(self) -> AsyncIterator[T]:
81101
"""Implement the async iterator protocol."""
82102
return self
83103

84-
async def _handle_exception(self, exc):
104+
async def _handle_exception(self, exc) -> None:
85105
"""
86106
When an exception is raised while iterating over the active_target,
87107
check if it is retryable. If so, create a new active_target and
@@ -114,7 +134,7 @@ async def _handle_exception(self, exc):
114134
self.active_target = None
115135
await self._ensure_active_target()
116136

117-
def _subtract_time_from_budget(self, start_timestamp):
137+
def _subtract_time_from_budget(self, start_timestamp: datetime.datetime) -> None:
118138
"""
119139
Subtract the time elapsed since start_timestamp from the remaining
120140
timeout budget.
@@ -128,13 +148,15 @@ def _subtract_time_from_budget(self, start_timestamp):
128148
datetime_helpers.utcnow() - start_timestamp
129149
).total_seconds()
130150

131-
async def _iteration_helper(self, iteration_routine: Awaitable):
151+
async def _iteration_helper(self, iteration_routine: Awaitable) -> T:
132152
"""
133153
Helper function for sharing logic between __anext__ and asend.
134154
135155
Args:
136156
- iteration_routine: The coroutine to await to get the next value
137157
from the iterator (e.g. __anext__ or asend)
158+
Returns:
159+
- The next value from the active_target iterator.
138160
"""
139161
# check for expired timeouts before attempting to iterate
140162
if (
@@ -164,16 +186,19 @@ async def _iteration_helper(self, iteration_routine: Awaitable):
164186
# if retryable exception was handled, find the next value to return
165187
return await self.__anext__()
166188

167-
async def __anext__(self):
189+
async def __anext__(self) -> T:
168190
"""
169191
Implement the async iterator protocol.
192+
193+
Returns:
194+
- The next value from the active_target iterator.
170195
"""
171-
await self._ensure_active_target()
196+
iterable = await self._ensure_active_target()
172197
return await self._iteration_helper(
173-
self.active_target.__anext__(),
198+
iterable.__anext__(),
174199
)
175200

176-
async def aclose(self):
201+
async def aclose(self) -> None:
177202
"""
178203
Close the active_target if supported. (e.g. target is an async generator)
179204
@@ -182,48 +207,57 @@ async def aclose(self):
182207
"""
183208
await self._ensure_active_target()
184209
if getattr(self.active_target, "aclose", None):
185-
return await self.active_target.aclose()
210+
casted_target = cast(AsyncGenerator[T, None], self.active_target)
211+
return await casted_target.aclose()
186212
else:
187213
raise AttributeError(
188214
"aclose() not implemented for {}".format(self.active_target)
189215
)
190216

191-
async def asend(self, value):
217+
async def asend(self, *args, **kwargs) -> T:
192218
"""
193219
Call asend on the active_target if supported. (e.g. target is an async generator)
194220
195221
If an exception is raised, a retry may be attempted before returning
196222
a result.
197223
198-
Returns:
199-
- the result of calling asend() on the active_target
200224
225+
Args:
226+
- *args: arguments to pass to the wrapped generator's asend method
227+
- **kwargs: keyword arguments to pass to the wrapped generator's asend method
228+
Returns:
229+
- the next value of the active_target iterator after calling asend
201230
Raises:
202231
- AttributeError if the active_target does not have a asend() method
203232
"""
204233
await self._ensure_active_target()
205234
if getattr(self.active_target, "asend", None):
206-
return await self._iteration_helper(self.active_target.asend(value))
235+
casted_target = cast(AsyncGenerator[T, None], self.active_target)
236+
return await self._iteration_helper(casted_target.asend(*args, **kwargs))
207237
else:
208238
raise AttributeError(
209239
"asend() not implemented for {}".format(self.active_target)
210240
)
211241

212-
async def athrow(self, typ, val=None, tb=None):
242+
async def athrow(self, *args, **kwargs) -> T:
213243
"""
214244
Call athrow on the active_target if supported. (e.g. target is an async generator)
215245
216246
If an exception is raised, a retry may be attempted before returning
217247
248+
Args:
249+
- *args: arguments to pass to the wrapped generator's athrow method
250+
- **kwargs: keyword arguments to pass to the wrapped generator's athrow method
218251
Returns:
219-
- the result of calling athrow() on the active_target
252+
- the next value of the active_target iterator after calling athrow
220253
Raises:
221254
- AttributeError if the active_target does not have a athrow() method
222255
"""
223256
await self._ensure_active_target()
224257
if getattr(self.active_target, "athrow", None):
258+
casted_target = cast(AsyncGenerator[T, None], self.active_target)
225259
try:
226-
return await self.active_target.athrow(typ, val, tb)
260+
return await casted_target.athrow(*args, **kwargs)
227261
except Exception as exc:
228262
await self._handle_exception(exc)
229263
# if retryable exception was handled, return next from new active_target

0 commit comments

Comments
 (0)