Skip to content

Commit 9ba7676

Browse files
committed
simplified retry_streaming_async to use wall time instead of cpu time
1 parent 43d0913 commit 9ba7676

File tree

2 files changed

+40
-49
lines changed

2 files changed

+40
-49
lines changed

google/api_core/retry_streaming_async.py

Lines changed: 39 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@
2929

3030
import asyncio
3131
import logging
32-
import time
33-
32+
import datetime
3433

34+
from google.api_core import datetime_helpers
3535
from google.api_core import exceptions
3636

3737
_LOGGER = logging.getLogger(__name__)
@@ -78,7 +78,35 @@ def __init__(
7878
self.sleep_generator = iter(sleep_generator)
7979
self.on_error = on_error
8080
self.timeout = timeout
81-
self.remaining_timeout_budget = timeout if timeout else None
81+
self.timeout_task = None
82+
if self.timeout is not None:
83+
self.deadline = datetime_helpers.utcnow() + datetime.timedelta(
84+
seconds=self.timeout
85+
)
86+
else:
87+
self.deadline = None
88+
89+
def _check_timeout(
90+
self, current_time: float, source_exception: Optional[Exception] = None
91+
) -> None:
92+
"""
93+
Helper function to check if the timeout has been exceeded, and raise a RetryError if so.
94+
95+
Args:
96+
- current_time: the timestamp to check against the deadline
97+
- source_exception: the exception that triggered the timeout check, if any
98+
Raises:
99+
- RetryError if the deadline has been exceeded
100+
"""
101+
if (
102+
self.deadline is not None
103+
and self.timeout is not None
104+
and self.deadline < current_time
105+
):
106+
raise exceptions.RetryError(
107+
"Timeout of {:.1f}s exceeded".format(self.timeout),
108+
source_exception,
109+
) from source_exception
82110

83111
async def _ensure_active_target(self) -> AsyncIterator[T]:
84112
"""
@@ -114,15 +142,12 @@ async def _handle_exception(self, exc) -> None:
114142
next_sleep = next(self.sleep_generator)
115143
except StopIteration:
116144
raise ValueError("Sleep generator stopped yielding sleep values")
117-
# if time budget is exceeded, raise RetryError
118-
if self.remaining_timeout_budget is not None and self.timeout is not None:
119-
if self.remaining_timeout_budget <= next_sleep:
120-
raise exceptions.RetryError(
121-
"Timeout of {:.1f}s exceeded".format(self.timeout),
122-
exc,
123-
) from exc
124-
else:
125-
self.remaining_timeout_budget -= next_sleep
145+
# if deadline is exceeded, raise RetryError
146+
if self.deadline is not None:
147+
next_attempt = datetime_helpers.utcnow() + datetime.timedelta(
148+
seconds=next_sleep
149+
)
150+
self._check_timeout(next_attempt, exc)
126151
_LOGGER.debug(
127152
"Retrying due to {}, sleeping {:.1f}s ...".format(exc, next_sleep)
128153
)
@@ -131,18 +156,6 @@ async def _handle_exception(self, exc) -> None:
131156
self.active_target = None
132157
await self._ensure_active_target()
133158

134-
def _subtract_time_from_budget(self, start_timestamp: float) -> None:
135-
"""
136-
Subtract the time elapsed since start_timestamp from the remaining
137-
timeout budget.
138-
139-
Args:
140-
- start_timestamp: The timestamp at which the last operation
141-
started.
142-
"""
143-
if self.remaining_timeout_budget is not None:
144-
self.remaining_timeout_budget -= time.monotonic() - start_timestamp
145-
146159
async def _iteration_helper(self, iteration_routine: Awaitable) -> T:
147160
"""
148161
Helper function for sharing logic between __anext__ and asend.
@@ -154,28 +167,13 @@ async def _iteration_helper(self, iteration_routine: Awaitable) -> T:
154167
- The next value from the active_target iterator.
155168
"""
156169
# check for expired timeouts before attempting to iterate
157-
if (
158-
self.remaining_timeout_budget is not None
159-
and self.remaining_timeout_budget <= 0
160-
and self.timeout is not None
161-
):
162-
raise exceptions.RetryError(
163-
"Timeout of {:.1f}s exceeded".format(self.timeout),
164-
None,
165-
)
170+
self._check_timeout(datetime_helpers.utcnow())
166171
try:
167-
# start the timer for the current operation
168-
start_timestamp = time.monotonic()
169172
# grab the next value from the active_target
170173
# Note: interrupting with asyncio.wait_for is expensive,
171174
# so we only check for timeouts at the start of each iteration
172-
next_val = await iteration_routine
173-
# subtract the time spent waiting for the next value from the
174-
# remaining timeout budget
175-
self._subtract_time_from_budget(start_timestamp)
176-
return next_val
175+
return await iteration_routine
177176
except (Exception, asyncio.CancelledError) as exc:
178-
self._subtract_time_from_budget(start_timestamp)
179177
await self._handle_exception(exc)
180178
# if retryable exception was handled, find the next value to return
181179
return await self.__anext__()

tests/asyncio/test_retry_async.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -765,14 +765,7 @@ async def test_iterate_stream_after_deadline(self):
765765
retry_ = retry_async.AsyncRetry(is_stream=True, deadline=0.01)
766766
decorated = retry_(self._generator_mock)
767767
generator = decorated(10)
768-
starting_time_budget = generator.remaining_timeout_budget
769-
assert starting_time_budget == 0.01
770768
await generator.__anext__()
771-
# ensure budget is used on each call
772-
assert generator.remaining_timeout_budget < starting_time_budget
773-
# simulate using up budget
774-
generator.remaining_timeout_budget = 0
769+
await asyncio.sleep(0.02)
775770
with pytest.raises(exceptions.RetryError):
776771
await generator.__anext__()
777-
with pytest.raises(exceptions.RetryError):
778-
await generator.asend("test")

0 commit comments

Comments
 (0)