Skip to content

Commit 0500b8b

Browse files
committed
fixed mypy issues
1 parent ba6dc9f commit 0500b8b

File tree

4 files changed

+30
-39
lines changed

4 files changed

+30
-39
lines changed

google/api_core/retry_streaming.py

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
Generator,
2525
TypeVar,
2626
Any,
27+
Union,
2728
cast,
2829
)
2930

@@ -38,13 +39,6 @@
3839
T = TypeVar("T")
3940

4041

41-
class _TerminalException(Exception):
42-
"""
43-
Exception to bypasses retry logic and raises __cause__ immediately.
44-
"""
45-
pass
46-
47-
4842
def _build_timeout_error(
4943
exc_list: List[Exception], is_timeout: bool, timeout_val: float
5044
) -> Tuple[Exception, Optional[Exception]]:
@@ -74,15 +68,13 @@ def _build_timeout_error(
7468

7569

7670
def retry_target_generator(
77-
target: Callable[[], Iterable[T]],
71+
target: Callable[[], Union[Iterable[T], Generator[T, Any, None]]],
7872
predicate: Callable[[Exception], bool],
7973
sleep_generator: Iterable[float],
8074
timeout: Optional[float] = None,
8175
on_error: Optional[Callable[[Exception], None]] = None,
8276
exception_factory: Optional[
83-
Callable[
84-
[List[Exception], bool, float], Tuple[Exception, Optional[Exception]]
85-
]
77+
Callable[[List[Exception], bool, float], Tuple[Exception, Optional[Exception]]]
8678
] = None,
8779
**kwargs,
8880
) -> Generator[T, Any, None]:
@@ -171,17 +163,19 @@ def on_error(e):
171163
except Exception as exc:
172164
error_list.append(exc)
173165
if not predicate(exc):
174-
exc, source_exc = exc_factory(exc_list=error_list, is_timeout=False)
175-
raise exc from source_exc
166+
final_exc, source_exc = exc_factory(
167+
exc_list=error_list, is_timeout=False
168+
)
169+
raise final_exc from source_exc
176170
if on_error is not None:
177171
on_error(exc)
178172
finally:
179173
if subgenerator is not None and getattr(subgenerator, "close", None):
180-
subgenerator.close()
174+
cast(Generator, subgenerator).close()
181175

182176
if deadline is not None and time.monotonic() + sleep > deadline:
183-
exc, source_exc = exc_factory(exc_list=error_list, is_timeout=True)
184-
raise exc from source_exc
177+
final_exc, source_exc = exc_factory(exc_list=error_list, is_timeout=True)
178+
raise final_exc from source_exc
185179
_LOGGER.debug(
186180
"Retrying due to {}, sleeping {:.1f}s ...".format(error_list[-1], sleep)
187181
)

google/api_core/retry_streaming_async.py

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,6 @@
4242
T = TypeVar("T")
4343

4444

45-
class _TerminalException(Exception):
46-
"""
47-
Exception to bypasses retry logic and raises __cause__ immediately.
48-
"""
49-
pass
50-
51-
5245
async def retry_target_generator(
5346
target: Union[
5447
Callable[[], AsyncIterable[T]],
@@ -133,12 +126,12 @@ def on_error(e):
133126
filter_retry_wrapped = retryable_with_filter(target)
134127
```
135128
"""
136-
subgenerator = None
137-
129+
subgenerator : Optional[AsyncIterator[T]] = None
138130
timeout = kwargs.get("deadline", timeout)
139-
140131
deadline: Optional[float] = time.monotonic() + timeout if timeout else None
132+
# keep track of retryable exceptions we encounter to pass in to exception_factory
141133
error_list: List[Exception] = []
134+
# override exception_factory to build a more complex exception
142135
exc_factory = partial(
143136
exception_factory or _build_timeout_error, timeout_val=timeout
144137
)
@@ -147,12 +140,13 @@ def on_error(e):
147140
# Start a new retry loop
148141
try:
149142
# generator may be raw iterator, or wrapped in an awaitable
150-
subgenerator = target()
143+
gen_instance: Union[AsyncIterable[T], Awaitable[AsyncIterable[T]]] = target()
151144
try:
152-
subgenerator = await subgenerator
145+
gen_instance = await gen_instance # type: ignore
153146
except TypeError:
154147
# was not awaitable
155148
pass
149+
subgenerator = cast(AsyncIterable[T], gen_instance).__aiter__()
156150

157151
# if target is a generator, we will advance it using asend
158152
# otherwise, we will use anext
@@ -162,7 +156,7 @@ def on_error(e):
162156
while True:
163157
## Read from Subgenerator
164158
if supports_send:
165-
next_value = await subgenerator.asend(sent_in)
159+
next_value = await subgenerator.asend(sent_in) # type: ignore
166160
else:
167161
next_value = await subgenerator.__anext__()
168162
## Yield from Wrapper to caller
@@ -173,20 +167,18 @@ def on_error(e):
173167
except GeneratorExit:
174168
# if wrapper received `aclose`, pass to subgenerator and close
175169
if bool(getattr(subgenerator, "aclose", None)):
176-
await subgenerator.aclose()
170+
await cast(AsyncGenerator[T, None], subgenerator).aclose()
177171
else:
178172
raise
179173
return
180174
except: # noqa: E722
181175
# bare except catches any exception passed to `athrow`
182176
# delegate error handling to subgenerator
183177
if getattr(subgenerator, "athrow", None):
184-
await subgenerator.athrow(*sys.exc_info())
178+
await cast(AsyncGenerator[T, None], subgenerator).athrow(*sys.exc_info())
185179
else:
186180
raise
187181
return
188-
except _TerminalException as exc:
189-
raise exc.__cause__ from exc.__cause__.__cause__
190182
except StopAsyncIteration:
191183
# if generator exhausted, return
192184
return
@@ -201,12 +193,12 @@ def on_error(e):
201193
on_error(exc)
202194
finally:
203195
if subgenerator is not None and getattr(subgenerator, "aclose", None):
204-
await subgenerator.aclose()
196+
await cast(AsyncGenerator[T, None], subgenerator).aclose()
205197

206198
# sleep and adjust timeout budget
207199
if deadline is not None and time.monotonic() + sleep > deadline:
208-
exc, source_exc = exc_factory(exc_list=error_list, is_timeout=True)
209-
raise exc from source_exc
200+
final_exc, source_exc = exc_factory(exc_list=error_list, is_timeout=True)
201+
raise final_exc from source_exc
210202
_LOGGER.debug(
211203
"Retrying due to {}, sleeping {:.1f}s ...".format(error_list[-1], sleep)
212204
)

tests/asyncio/test_retry_async.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -705,8 +705,10 @@ async def __anext__(self):
705705
return CustomIterable(n)
706706

707707
if awaitale_wrapped:
708+
708709
async def wrapper(n):
709710
return iterable_fn(n)
711+
710712
decorated = retry_(wrapper)
711713
else:
712714
decorated = retry_(iterable_fn)
@@ -718,7 +720,6 @@ async def wrapper(n):
718720
await retryable.asend("test2") == 2
719721
await retryable.asend("test3") == 3
720722

721-
722723
@pytest.mark.parametrize("awaitale_wrapped", [True, False])
723724
@mock.patch("asyncio.sleep", autospec=True)
724725
@pytest.mark.asyncio
@@ -744,9 +745,12 @@ async def __anext__(self):
744745
return self.i - 1
745746

746747
return CustomIterable(n)
748+
747749
if awaitale_wrapped:
750+
748751
async def wrapper(n):
749752
return iterable_fn(n)
753+
750754
decorated = retry_(wrapper)
751755
else:
752756
decorated = retry_(iterable_fn)
@@ -763,7 +767,6 @@ async def wrapper(n):
763767
with pytest.raises(StopAsyncIteration):
764768
await new_retryable.__anext__()
765769

766-
767770
@pytest.mark.parametrize("awaitale_wrapped", [True, False])
768771
@mock.patch("asyncio.sleep", autospec=True)
769772
@pytest.mark.asyncio
@@ -791,9 +794,12 @@ async def __anext__(self):
791794
return self.i - 1
792795

793796
return CustomIterable(n)
797+
794798
if awaitale_wrapped:
799+
795800
async def wrapper(n):
796801
return iterable_fn(n)
802+
797803
decorated = retry_(wrapper)
798804
else:
799805
decorated = retry_(iterable_fn)

tests/unit/test_retry.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -668,7 +668,6 @@ def iterable_fn(n):
668668
with pytest.raises(AttributeError):
669669
generator.send("test")
670670

671-
672671
@mock.patch("time.sleep", autospec=True)
673672
def test___call___with_iterable_close(self, sleep):
674673
"""

0 commit comments

Comments
 (0)