14
14
15
15
"""Helpers for retries for async streaming APIs."""
16
16
17
- from typing import Callable , Optional , Iterable , AsyncIterable , Awaitable , Union
17
+ from typing import (
18
+ cast ,
19
+ Callable ,
20
+ Optional ,
21
+ Iterable ,
22
+ AsyncIterator ,
23
+ AsyncIterable ,
24
+ Awaitable ,
25
+ Union ,
26
+ Any ,
27
+ TypeVar ,
28
+ AsyncGenerator ,
29
+ )
18
30
19
31
import asyncio
20
32
import inspect
21
33
import logging
34
+ import datetime
22
35
23
- from collections .abc import AsyncGenerator
24
36
25
37
from google .api_core import datetime_helpers
26
38
from google .api_core import exceptions
27
39
28
40
_LOGGER = logging .getLogger (__name__ )
29
41
42
+ T = TypeVar ("T" )
30
43
31
- class AsyncRetryableGenerator (AsyncGenerator ):
44
+
45
+ class AsyncRetryableGenerator (AsyncGenerator [T , None ]):
32
46
"""
33
47
Helper class for retrying AsyncIterator and AsyncGenerator-based
34
48
streaming APIs.
@@ -37,7 +51,8 @@ class AsyncRetryableGenerator(AsyncGenerator):
37
51
def __init__ (
38
52
self ,
39
53
target : Union [
40
- Callable [[], AsyncIterable ], Callable [[], Awaitable [AsyncIterable ]]
54
+ Callable [[], AsyncIterable [T ]],
55
+ Callable [[], Awaitable [AsyncIterable [T ]]],
41
56
],
42
57
predicate : Callable [[Exception ], bool ],
43
58
sleep_generator : Iterable [float ],
@@ -61,27 +76,32 @@ def __init__(
61
76
"""
62
77
self .target_fn = target
63
78
# active target must be populated in an async context
64
- self .active_target : Optional [AsyncIterable ] = None
79
+ self .active_target : Optional [AsyncIterator [ T ] ] = None
65
80
self .predicate = predicate
66
81
self .sleep_generator = iter (sleep_generator )
67
82
self .on_error = on_error
68
83
self .timeout = timeout
69
84
self .remaining_timeout_budget = timeout if timeout else None
70
85
71
- async def _ensure_active_target (self ):
86
+ async def _ensure_active_target (self ) -> AsyncIterator [ T ] :
72
87
"""
73
88
Ensure that the active target is populated and ready to be iterated over.
89
+
90
+ Returns:
91
+ - The active_target iterable
74
92
"""
75
93
if not self .active_target :
76
- self .active_target = self .target_fn ()
77
- if inspect .iscoroutine (self .active_target ):
78
- self .active_target = await self .active_target
94
+ new_iterable = self .target_fn ()
95
+ if isinstance (new_iterable , Awaitable ):
96
+ new_iterable = await new_iterable
97
+ self .active_target = new_iterable .__aiter__ ()
98
+ return self .active_target
79
99
80
- def __aiter__ (self ):
100
+ def __aiter__ (self ) -> AsyncIterator [ T ] :
81
101
"""Implement the async iterator protocol."""
82
102
return self
83
103
84
- async def _handle_exception (self , exc ):
104
+ async def _handle_exception (self , exc ) -> None :
85
105
"""
86
106
When an exception is raised while iterating over the active_target,
87
107
check if it is retryable. If so, create a new active_target and
@@ -114,7 +134,7 @@ async def _handle_exception(self, exc):
114
134
self .active_target = None
115
135
await self ._ensure_active_target ()
116
136
117
- def _subtract_time_from_budget (self , start_timestamp ) :
137
+ def _subtract_time_from_budget (self , start_timestamp : datetime . datetime ) -> None :
118
138
"""
119
139
Subtract the time elapsed since start_timestamp from the remaining
120
140
timeout budget.
@@ -128,13 +148,15 @@ def _subtract_time_from_budget(self, start_timestamp):
128
148
datetime_helpers .utcnow () - start_timestamp
129
149
).total_seconds ()
130
150
131
- async def _iteration_helper (self , iteration_routine : Awaitable ):
151
+ async def _iteration_helper (self , iteration_routine : Awaitable ) -> T :
132
152
"""
133
153
Helper function for sharing logic between __anext__ and asend.
134
154
135
155
Args:
136
156
- iteration_routine: The coroutine to await to get the next value
137
157
from the iterator (e.g. __anext__ or asend)
158
+ Returns:
159
+ - The next value from the active_target iterator.
138
160
"""
139
161
# check for expired timeouts before attempting to iterate
140
162
if (
@@ -164,16 +186,19 @@ async def _iteration_helper(self, iteration_routine: Awaitable):
164
186
# if retryable exception was handled, find the next value to return
165
187
return await self .__anext__ ()
166
188
167
- async def __anext__ (self ):
189
+ async def __anext__ (self ) -> T :
168
190
"""
169
191
Implement the async iterator protocol.
192
+
193
+ Returns:
194
+ - The next value from the active_target iterator.
170
195
"""
171
- await self ._ensure_active_target ()
196
+ iterable = await self ._ensure_active_target ()
172
197
return await self ._iteration_helper (
173
- self . active_target .__anext__ (),
198
+ iterable .__anext__ (),
174
199
)
175
200
176
- async def aclose (self ):
201
+ async def aclose (self ) -> None :
177
202
"""
178
203
Close the active_target if supported. (e.g. target is an async generator)
179
204
@@ -182,48 +207,57 @@ async def aclose(self):
182
207
"""
183
208
await self ._ensure_active_target ()
184
209
if getattr (self .active_target , "aclose" , None ):
185
- return await self .active_target .aclose ()
210
+ casted_target = cast (AsyncGenerator [T , None ], self .active_target )
211
+ return await casted_target .aclose ()
186
212
else :
187
213
raise AttributeError (
188
214
"aclose() not implemented for {}" .format (self .active_target )
189
215
)
190
216
191
- async def asend (self , value ) :
217
+ async def asend (self , * args , ** kwargs ) -> T :
192
218
"""
193
219
Call asend on the active_target if supported. (e.g. target is an async generator)
194
220
195
221
If an exception is raised, a retry may be attempted before returning
196
222
a result.
197
223
198
- Returns:
199
- - the result of calling asend() on the active_target
200
224
225
+ Args:
226
+ - *args: arguments to pass to the wrapped generator's asend method
227
+ - **kwargs: keyword arguments to pass to the wrapped generator's asend method
228
+ Returns:
229
+ - the next value of the active_target iterator after calling asend
201
230
Raises:
202
231
- AttributeError if the active_target does not have a asend() method
203
232
"""
204
233
await self ._ensure_active_target ()
205
234
if getattr (self .active_target , "asend" , None ):
206
- return await self ._iteration_helper (self .active_target .asend (value ))
235
+ casted_target = cast (AsyncGenerator [T , None ], self .active_target )
236
+ return await self ._iteration_helper (casted_target .asend (* args , ** kwargs ))
207
237
else :
208
238
raise AttributeError (
209
239
"asend() not implemented for {}" .format (self .active_target )
210
240
)
211
241
212
- async def athrow (self , typ , val = None , tb = None ) :
242
+ async def athrow (self , * args , ** kwargs ) -> T :
213
243
"""
214
244
Call athrow on the active_target if supported. (e.g. target is an async generator)
215
245
216
246
If an exception is raised, a retry may be attempted before returning
217
247
248
+ Args:
249
+ - *args: arguments to pass to the wrapped generator's athrow method
250
+ - **kwargs: keyword arguments to pass to the wrapped generator's athrow method
218
251
Returns:
219
- - the result of calling athrow() on the active_target
252
+ - the next value of the active_target iterator after calling athrow
220
253
Raises:
221
254
- AttributeError if the active_target does not have a athrow() method
222
255
"""
223
256
await self ._ensure_active_target ()
224
257
if getattr (self .active_target , "athrow" , None ):
258
+ casted_target = cast (AsyncGenerator [T , None ], self .active_target )
225
259
try :
226
- return await self . active_target . athrow (typ , val , tb )
260
+ return await casted_target . athrow (* args , ** kwargs )
227
261
except Exception as exc :
228
262
await self ._handle_exception (exc )
229
263
# if retryable exception was handled, return next from new active_target
0 commit comments