Skip to content

Commit 74f3f3e

Browse files
committed
fixed send and asend retry logic
1 parent 902a4ab commit 74f3f3e

File tree

4 files changed

+53
-24
lines changed

4 files changed

+53
-24
lines changed

google/api_core/retry_streaming.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,8 +151,8 @@ def send(self, value):
151151
return self.active_target.send(value)
152152
except Exception as exc:
153153
self._handle_exception(exc)
154-
# if retryable exception was handled, try again with new active_target
155-
return self.send(value)
154+
# if exception was retryable, use new target for return value
155+
return self.__next__()
156156
else:
157157
raise AttributeError(
158158
"send() not implemented for {}".format(self.active_target)

google/api_core/retry_streaming_async.py

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -128,20 +128,13 @@ def _subtract_time_from_budget(self, start_timestamp):
128128
datetime_helpers.utcnow() - start_timestamp
129129
).total_seconds()
130130

131-
async def _iteration_helper(
132-
self, iteration_fn: Callable[..., Awaitable], try_again_fn: Callable, *args
133-
):
131+
async def _iteration_helper(self, iteration_routine: Awaitable):
134132
"""
135133
Helper function for sharing logic between __anext__ and asend.
136134
137135
Args:
138-
- iteration_fn: The function to call to get the next value from the
139-
iterator (e.g. __anext__ or asend)
140-
- try_again_fn: The function to call after a retryable exception is
141-
encountered, to get a value from the new active_target
142-
(e.g. self.__anext__ or self.asend)
143-
- *args: Any additional arguments to pass to iteration_fn and
144-
try_again_fn (e.g. the value to send to asend)
136+
- iteration_routine: The coroutine to await to get the next value
137+
from the iterator (e.g. __anext__ or asend)
145138
"""
146139
# check for expired timeouts before attempting to iterate
147140
if (
@@ -158,7 +151,7 @@ async def _iteration_helper(
158151
start_timestamp = datetime_helpers.utcnow()
159152
# grab the next value from the active_target
160153
next_val_routine = asyncio.wait_for(
161-
iteration_fn(*args), self.remaining_timeout_budget
154+
iteration_routine, self.remaining_timeout_budget
162155
)
163156
next_val = await next_val_routine
164157
# subtract the time spent waiting for the next value from the
@@ -168,16 +161,16 @@ async def _iteration_helper(
168161
except (Exception, asyncio.CancelledError) as exc:
169162
self._subtract_time_from_budget(start_timestamp)
170163
await self._handle_exception(exc)
171-
# if retryable exception was handled, try again with new active_target
172-
return await try_again_fn(*args)
164+
# if retryable exception was handled, find the next value to return
165+
return await self.__anext__()
173166

174167
async def __anext__(self):
175168
"""
176169
Implement the async iterator protocol.
177170
"""
178171
await self._ensure_active_target()
179172
return await self._iteration_helper(
180-
self.active_target.__anext__, self.__anext__
173+
self.active_target.__anext__(),
181174
)
182175

183176
async def aclose(self):
@@ -210,9 +203,7 @@ async def asend(self, value):
210203
"""
211204
await self._ensure_active_target()
212205
if getattr(self.active_target, "asend", None):
213-
return await self._iteration_helper(
214-
self.active_target.asend, self.asend, value
215-
)
206+
return await self._iteration_helper(self.active_target.asend(value))
216207
else:
217208
raise AttributeError(
218209
"asend() not implemented for {}".format(self.active_target)

tests/asyncio/test_retry_async.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -421,7 +421,12 @@ async def test___init___when_retry_is_executed(self, sleep, uniform):
421421
sleep.assert_any_call(retry_._initial)
422422

423423
async def _generator_mock(
424-
self, num=5, error_on=None, exceptions_seen=None, sleep_time=0
424+
self,
425+
num=5,
426+
error_on=None,
427+
exceptions_seen=None,
428+
sleep_time=0,
429+
ignore_sent=False,
425430
):
426431
try:
427432
sent_in = None
@@ -431,6 +436,8 @@ async def _generator_mock(
431436
if error_on and i == error_on:
432437
raise ValueError("generator mock error")
433438
sent_in = yield (sent_in if sent_in else i)
439+
if ignore_sent:
440+
sent_in = None
434441
except (Exception, BaseException, GeneratorExit) as e:
435442
# keep track of exceptions seen by generator
436443
if exceptions_seen is not None:
@@ -607,6 +614,28 @@ async def test___call___with_generator_send(self, sleep):
607614
assert await generator.__anext__() == 4
608615
assert await generator.__anext__() == 5
609616

617+
@mock.patch("asyncio.sleep", autospec=True)
618+
@pytest.mark.asyncio
619+
async def test___call___generator_send_retry(self, sleep):
620+
on_error = mock.Mock(return_value=None)
621+
retry_ = retry_async.AsyncRetry(
622+
on_error=on_error,
623+
predicate=retry_async.if_exception_type(ValueError),
624+
is_stream=True,
625+
timeout=None,
626+
)
627+
generator = retry_(self._generator_mock)(error_on=3, ignore_sent=True)
628+
with pytest.raises(TypeError) as exc_info:
629+
await generator.asend("can not send to fresh generator")
630+
assert exc_info.match("can't send non-None value")
631+
632+
# error thrown on 3
633+
# generator should contain 0, 1, 2 looping
634+
assert await generator.__anext__() == 0
635+
unpacked = [await generator.asend(i) for i in range(10)]
636+
assert unpacked == [1, 2, 0, 1, 2, 0, 1, 2, 0, 1]
637+
assert on_error.call_count == 3
638+
610639
@mock.patch("asyncio.sleep", autospec=True)
611640
@pytest.mark.asyncio
612641
async def test___call___with_generator_close(self, sleep):

tests/unit/test_retry.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -484,14 +484,21 @@ def test___init___when_retry_is_executed(self, sleep, uniform):
484484
sleep.assert_any_call(retry_._initial)
485485

486486
def _generator_mock(
487-
self, num=5, error_on=None, return_val=None, exceptions_seen=None
487+
self,
488+
num=5,
489+
error_on=None,
490+
return_val=None,
491+
exceptions_seen=None,
492+
ignore_sent=False,
488493
):
489494
try:
490495
sent_in = None
491496
for i in range(num):
492497
if error_on and i == error_on:
493498
raise ValueError("generator mock error")
494499
sent_in = yield (sent_in if sent_in else i)
500+
if ignore_sent:
501+
sent_in = None
495502
return return_val
496503
except (Exception, BaseException, GeneratorExit) as e:
497504
# keep track of exceptions seen by generator
@@ -610,14 +617,16 @@ def test___call___with_generator_send_retry(self, sleep):
610617
is_stream=True,
611618
timeout=None,
612619
)
613-
result = retry_(self._generator_mock)(error_on=3)
620+
result = retry_(self._generator_mock)(error_on=3, ignore_sent=True)
614621
with pytest.raises(TypeError) as exc_info:
615622
result.send("can not send to fresh generator")
616623
assert exc_info.match("can't send non-None value")
624+
# initiate iteration with None
625+
assert result.send(None) == 0
617626
# error thrown on 3
618627
# generator should contain 0, 1, 2 looping
619-
unpacked = [result.send(None) for i in range(10)]
620-
assert unpacked == [0, 1, 2, 0, 1, 2, 0, 1, 2, 0]
628+
unpacked = [result.send(i) for i in range(10)]
629+
assert unpacked == [1, 2, 0, 1, 2, 0, 1, 2, 0, 1]
621630
assert on_error.call_count == 3
622631

623632
@mock.patch("time.sleep", autospec=True)

0 commit comments

Comments
 (0)