
Commit 988a52d

khakhlyuk authored and haoyangeng-db committed
[SPARK-52673][CONNECT][CLIENT] Add grpc RetryInfo handling to Spark Connect retry policies
### What changes were proposed in this pull request?

Spark Connect Client has a set of retry policies that specify which errors coming from the server can be retried. This change adds the capability for the Spark Connect Client to use server-provided retry information, following the gRPC standard: https://github.com/googleapis/googleapis/blob/master/google/rpc/error_details.proto#L91

The server can include a `RetryInfo` gRPC message containing a `retry_delay` field in its error response. The client will now use the `RetryInfo` message to classify the error as retriable and will use `retry_delay` to calculate how long to wait before the next attempt. This behavior is in line with the gRPC standard for client-server communication.

The change is needed for two reasons:
1. If the server is under heavy load or a task takes more time, it can tell the client to wait longer using the `retry_delay` field.
2. If the server needs to introduce a new retryable error, it can simply include `RetryInfo` in the error message. The error will then be retried automatically by the client; no changes to the client-side retry policies are needed.

#### Changes in detail

- Add new `recognize_server_retry_delay` and `max_server_retry_delay` options to the `RetryPolicy` classes in the Python and Scala clients.
- All policies with `recognize_server_retry_delay=True` take `RetryInfo.retry_delay` into account when calculating the next backoff.
- `retry_delay` can override the client's `max_backoff`.
- `retry_delay` is limited by `max_server_retry_delay` (10 minutes by default).
- When the server stops sending high retry delays, the client goes back to using its own backoff policy limited by `max_backoff`.
- `DefaultPolicy` has `recognize_server_retry_delay=True` and uses `retry_delay` in the backoff calculation.
- Additionally, `DefaultPolicy` classifies all errors carrying `RetryInfo` as retryable.
- If an error can be retried by several policies, it is retried only by the first matching policy (highest priority). This change is needed because `DefaultPolicy` now retries all errors with `RetryInfo`; under the previous behaviour, an error that both carries `RetryInfo` and is matched by a different `CustomPolicy` would be retried by both the `DefaultPolicy` and the `CustomPolicy`, which can lead to excessively long retry periods and complicates planning of total retry times.
- Move the retry-policy-related tests from `test_client.py` to a new `test_client_retries.py` file; same for Scala.
- Extend docstrings.

### Why are the changes needed?

See above.

### Does this PR introduce _any_ user-facing change?

1. The clients retry all errors carrying a `RetryInfo` gRPC message using the `DefaultPolicy`.
2. An error is only retried by the first policy that matches it.

### How was this patch tested?

Old and new unit tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#51363 from khakhlyuk/retryinfo.

Authored-by: Alex Khakhlyuk <alex.khakhlyuk@gmail.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
1 parent 10c7e3b commit 988a52d
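
For illustration only (not part of this commit): a minimal sketch of how a gRPC server can attach `RetryInfo.retry_delay` to an error response, which is the server-side half of the behavior described above. It assumes the `grpcio`, `grpcio-status`, and `googleapis-common-protos` packages; the helper name `busy_status` and the 30-second delay are made up for the example.

```python
import grpc
from google.protobuf import any_pb2
from google.rpc import code_pb2, error_details_pb2, status_pb2
from grpc_status import rpc_status


def busy_status(delay_seconds: int) -> grpc.Status:
    """Build a grpc.Status whose details carry RetryInfo.retry_delay."""
    retry_info = error_details_pb2.RetryInfo()
    retry_info.retry_delay.FromSeconds(delay_seconds)  # how long clients should wait
    detail = any_pb2.Any()
    detail.Pack(retry_info)  # RetryInfo travels as a packed Any in the status details
    proto_status = status_pb2.Status(
        code=code_pb2.UNAVAILABLE,
        message="server overloaded, please retry later",
        details=[detail],
    )
    return rpc_status.to_status(proto_status)


# Inside a servicer method, the server would then abort with:
#     context.abort_with_status(busy_status(30))
```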

File tree: 8 files changed (+662, -172 lines)

python/pyspark/sql/connect/client/retries.py

Lines changed: 81 additions & 8 deletions
```diff
@@ -19,7 +19,9 @@
 import random
 import time
 import typing
-from typing import Optional, Callable, Generator, List, Type
+from google.rpc import error_details_pb2
+from grpc_status import rpc_status
+from typing import Optional, Callable, Generator, List, Type, cast
 from types import TracebackType
 from pyspark.sql.connect.logging import logger
 from pyspark.errors import PySparkRuntimeError, RetriesExceeded
@@ -45,6 +47,34 @@ class RetryPolicy:
     Describes key aspects of RetryPolicy.

     It's advised that different policies are implemented as different subclasses.
+
+    Parameters
+    ----------
+    max_retries: int, optional
+        Maximum number of retries.
+    initial_backoff: int
+        Start value of the exponential backoff.
+    max_backoff: int, optional
+        Maximal value of the exponential backoff.
+    backoff_multiplier: float
+        Multiplicative base of the exponential backoff.
+    jitter: int
+        Sample a random value uniformly from the range [0, jitter] and add it to the backoff.
+    min_jitter_threshold: int
+        Minimal value of the backoff to add random jitter.
+    recognize_server_retry_delay: bool
+        Per gRPC standard, the server can send error messages that contain `RetryInfo` message
+        with `retry_delay` field indicating that the client should wait for at least `retry_delay`
+        amount of time before retrying again, see:
+        https://github.com/googleapis/googleapis/blob/master/google/rpc/error_details.proto#L91
+
+        If this flag is set to true, RetryPolicy will use `RetryInfo.retry_delay` field
+        in the backoff computation. Server's `retry_delay` can override client's `max_backoff`.
+
+        This flag does not change which errors are retried, only how the backoff is computed.
+        `DefaultPolicy` additionally has a rule for retrying any error that contains `RetryInfo`.
+    max_server_retry_delay: int, optional
+        Limit for the server-provided `retry_delay`.
     """

     def __init__(
```
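
As a usage sketch of the parameters documented above (hypothetical, not part of the diff): a custom policy that retries throttling errors and honors server-provided delays up to a one-minute cap. The class name `ThrottlePolicy` and the status-code choice are illustrative.

```python
import grpc
from pyspark.sql.connect.client.retries import RetryPolicy


class ThrottlePolicy(RetryPolicy):
    """Hypothetical policy: retry throttling errors, honoring server delays."""

    def __init__(self) -> None:
        super().__init__(
            max_retries=10,
            initial_backoff=500,                # ms
            max_backoff=30_000,                 # ms, cap on the client's own backoff
            backoff_multiplier=2.0,
            recognize_server_retry_delay=True,  # honor RetryInfo.retry_delay
            max_server_retry_delay=60_000,      # ms, cap on the server's retry_delay
        )

    def can_retry(self, e: BaseException) -> bool:
        return (
            isinstance(e, grpc.RpcError)
            and e.code() == grpc.StatusCode.RESOURCE_EXHAUSTED
        )
```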
```diff
@@ -55,13 +85,17 @@ def __init__(
         backoff_multiplier: float = 1.0,
         jitter: int = 0,
         min_jitter_threshold: int = 0,
+        recognize_server_retry_delay: bool = False,
+        max_server_retry_delay: Optional[int] = None,
     ):
         self.max_retries = max_retries
         self.initial_backoff = initial_backoff
         self.max_backoff = max_backoff
         self.backoff_multiplier = backoff_multiplier
         self.jitter = jitter
         self.min_jitter_threshold = min_jitter_threshold
+        self.recognize_server_retry_delay = recognize_server_retry_delay
+        self.max_server_retry_delay = max_server_retry_delay
         self._name = self.__class__.__name__

     @property
@@ -98,7 +132,7 @@ def name(self) -> str:
     def can_retry(self, exception: BaseException) -> bool:
         return self.policy.can_retry(exception)

-    def next_attempt(self) -> Optional[int]:
+    def next_attempt(self, exception: Optional[BaseException] = None) -> Optional[int]:
         """
         Returns
         -------
@@ -119,6 +153,14 @@ def next_attempt(self) -> Optional[int]:
             float(self.policy.max_backoff), wait_time * self.policy.backoff_multiplier
         )

+        if exception is not None and self.policy.recognize_server_retry_delay:
+            retry_delay = extract_retry_delay(exception)
+            if retry_delay is not None:
+                logger.debug(f"The server has sent a retry delay of {retry_delay} ms.")
+                if self.policy.max_server_retry_delay is not None:
+                    retry_delay = min(retry_delay, self.policy.max_server_retry_delay)
+                wait_time = max(wait_time, retry_delay)
+
         # Jitter current backoff, after the future backoff was computed
         if wait_time >= self.policy.min_jitter_threshold:
             wait_time += random.uniform(0, self.policy.jitter)
```
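
To make the computation above concrete, here is a small self-contained mirror of the rule (an illustration, not the code in the diff): the server delay is first capped by `max_server_retry_delay`, then the client waits for the larger of its own exponential backoff and the capped server delay; jitter is applied afterwards.

```python
from typing import Optional


def effective_wait_ms(
    client_backoff: float, server_delay: Optional[int], max_server_delay: int
) -> float:
    """Cap the server-provided delay, then take the max with the client backoff."""
    if server_delay is None:
        return client_backoff  # no RetryInfo: plain client-side backoff
    return max(client_backoff, min(server_delay, max_server_delay))


assert effective_wait_ms(4_000, 30_000, 600_000) == 30_000      # obey a longer server delay
assert effective_wait_ms(4_000, None, 600_000) == 4_000         # no RetryInfo sent
assert effective_wait_ms(4_000, 3_600_000, 600_000) == 600_000  # capped at 10 minutes
```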
```diff
@@ -160,6 +202,7 @@ class Retrying:
     This class is a point of entry into the retry logic.
     The class accepts a list of retry policies and applies them in given order.
     The first policy accepting an exception will be used.
+    If the error was matched by one policy, the other policies will be skipped.

     The usage of the class should be as follows:
     for attempt in Retrying(...):
@@ -217,17 +260,18 @@ def _wait(self) -> None:
             return

         # Attempt to find a policy to wait with
+        matched_policy = None
         for policy in self._policies:
-            if not policy.can_retry(exception):
-                continue
-
-            wait_time = policy.next_attempt()
+            if policy.can_retry(exception):
+                matched_policy = policy
+                break
+        if matched_policy is not None:
+            wait_time = matched_policy.next_attempt(exception)
             if wait_time is not None:
                 logger.debug(
                     f"Got error: {repr(exception)}. "
-                    + f"Will retry after {wait_time} ms (policy: {policy.name})"
+                    + f"Will retry after {wait_time} ms (policy: {matched_policy.name})"
                 )
-
                 self._sleep(wait_time / 1000)
                 return
```
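
A hedged sketch of the first-match semantics implemented above, reusing the hypothetical `ThrottlePolicy` from the earlier sketch; `do_rpc` is a placeholder for any Spark Connect call, and `sleep` is injected so the loop records waits instead of actually sleeping.

```python
from pyspark.sql.connect.client.retries import DefaultPolicy, Retrying


def do_rpc() -> None:
    """Placeholder for a Spark Connect RPC; succeeds immediately here."""


waits = []  # collects sleep durations in seconds instead of sleeping
for attempt in Retrying([ThrottlePolicy(), DefaultPolicy()], sleep=waits.append):
    with attempt:
        do_rpc()

# Had do_rpc() raised a RESOURCE_EXHAUSTED RpcError, ThrottlePolicy (listed first)
# would be the only policy to retry it; DefaultPolicy would be skipped even if the
# error also carried RetryInfo (the new first-match behavior).
```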

```diff
@@ -274,6 +318,8 @@ def __init__(
         max_backoff: Optional[int] = 60000,
         jitter: int = 500,
         min_jitter_threshold: int = 2000,
+        recognize_server_retry_delay: bool = True,
+        max_server_retry_delay: Optional[int] = 10 * 60 * 1000,  # 10 minutes
     ):
         super().__init__(
             max_retries=max_retries,
@@ -282,6 +328,8 @@
             max_backoff=max_backoff,
             jitter=jitter,
             min_jitter_threshold=min_jitter_threshold,
+            recognize_server_retry_delay=recognize_server_retry_delay,
+            max_server_retry_delay=max_server_retry_delay,
         )

     def can_retry(self, e: BaseException) -> bool:
```
```diff
@@ -314,4 +362,29 @@ def can_retry(self, e: BaseException) -> bool:
         if e.code() == grpc.StatusCode.UNAVAILABLE:
             return True

+        if extract_retry_info(e) is not None:
+            # All error messages containing `RetryInfo` should be retried.
+            return True
+
         return False
+
+
+def extract_retry_info(exception: BaseException) -> Optional[error_details_pb2.RetryInfo]:
+    """Extract and return RetryInfo from the grpc.RpcError"""
+    if isinstance(exception, grpc.RpcError):
+        status = rpc_status.from_call(cast(grpc.Call, exception))
+        if status:
+            for d in status.details:
+                if d.Is(error_details_pb2.RetryInfo.DESCRIPTOR):
+                    info = error_details_pb2.RetryInfo()
+                    d.Unpack(info)
+                    return info
+    return None
+
+
+def extract_retry_delay(exception: BaseException) -> Optional[int]:
+    """Extract and return RetryInfo.retry_delay in milliseconds from grpc.RpcError if present."""
+    retry_info = extract_retry_info(exception)
+    if retry_info is not None:
+        return retry_info.retry_delay.ToMilliseconds()
+    return None
```
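
A short usage example of the new helpers in isolation (hedged: the error object must be a real `grpc.RpcError` produced by a failed call for `rpc_status.from_call` to find anything; the function name `log_server_hint` is made up).

```python
import grpc
from pyspark.sql.connect.client.retries import extract_retry_delay


def log_server_hint(e: grpc.RpcError) -> None:
    """Inspect a caught RpcError for a server-suggested retry delay."""
    delay_ms = extract_retry_delay(e)
    if delay_ms is not None:
        print(f"server asked us to wait at least {delay_ms} ms before retrying")
    else:
        print("no RetryInfo attached; fall back to the client-side backoff")
```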

python/pyspark/sql/tests/connect/client/test_client.py

Lines changed: 1 addition & 30 deletions
```diff
@@ -35,7 +35,7 @@
 )
 from pyspark.sql.connect.client.reattach import ExecutePlanResponseReattachableIterator
 from pyspark.sql.connect.session import SparkSession as RemoteSparkSession
-from pyspark.errors import PySparkRuntimeError, RetriesExceeded
+from pyspark.errors import PySparkRuntimeError
 import pyspark.sql.connect.proto as proto

 class TestPolicy(DefaultPolicy):
@@ -227,35 +227,6 @@ def test_is_closed(self):
         client.close()
         self.assertTrue(client.is_closed)

-    def test_retry(self):
-        client = SparkConnectClient("sc://foo/;token=bar")
-
-        total_sleep = 0
-
-        def sleep(t):
-            nonlocal total_sleep
-            total_sleep += t
-
-        try:
-            for attempt in Retrying(client._retry_policies, sleep=sleep):
-                with attempt:
-                    raise TestException("Retryable error", grpc.StatusCode.UNAVAILABLE)
-        except RetriesExceeded:
-            pass
-
-        # tolerated at least 10 mins of fails
-        self.assertGreaterEqual(total_sleep, 600)
-
-    def test_retry_client_unit(self):
-        client = SparkConnectClient("sc://foo/;token=bar")
-
-        policyA = TestPolicy()
-        policyB = DefaultPolicy()
-
-        client.set_retry_policies([policyA, policyB])
-
-        self.assertEqual(client.get_retry_policies(), [policyA, policyB])
-
     def test_channel_builder_with_session(self):
         dummy = str(uuid.uuid4())
         chan = DefaultChannelBuilder(f"sc://foo/;session_id={dummy}")
```
