Skip to content

Commit d5541b5

Browse files
committed
scala
1 parent c979b5f commit d5541b5

File tree

4 files changed

+352
-132
lines changed

4 files changed

+352
-132
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,281 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
package org.apache.spark.sql.connect.client
18+
19+
import scala.concurrent.duration.FiniteDuration
20+
21+
import com.google.protobuf.{Any, Duration}
22+
import com.google.rpc
23+
import io.grpc.{Status, StatusRuntimeException}
24+
import io.grpc.protobuf.StatusProto
25+
import org.scalatest.BeforeAndAfterEach
26+
import org.scalatest.concurrent.Eventually
27+
28+
import org.apache.spark.sql.connect.test.ConnectFunSuite
29+
30+
class SparkConnectClientRetriesSuite
31+
extends ConnectFunSuite
32+
with BeforeAndAfterEach
33+
with Eventually {
34+
35+
private class DummyFn(e: => Throwable, numFails: Int = 3) {
36+
var counter = 0
37+
def fn(): Int = {
38+
if (counter < numFails) {
39+
counter += 1
40+
throw e
41+
} else {
42+
42
43+
}
44+
}
45+
}
46+
47+
/** Tracks sleep times in milliseconds for testing purposes. */
48+
private class SleepTimeTracker {
49+
private val data = scala.collection.mutable.ListBuffer[Long]()
50+
def sleep(t: Long): Unit = data.append(t)
51+
def times: List[Long] = data.toList
52+
def totalSleep: Long = data.sum
53+
}
54+
55+
/** Helper function for creating a test exception with retry_delay */
56+
private def createTestExceptionWithDetails(
57+
msg: String,
58+
code: Status.Code = Status.Code.INTERNAL,
59+
retryDelay: FiniteDuration = FiniteDuration(0, "s")): StatusRuntimeException = {
60+
// In grpc-java, RetryDelay should be specified as seconds: Long + nanos: Int
61+
val seconds = retryDelay.toSeconds
62+
val nanos = (retryDelay - FiniteDuration(seconds, "s")).toNanos.toInt
63+
val retryDelayMsg = Duration
64+
.newBuilder()
65+
.setSeconds(seconds)
66+
.setNanos(nanos)
67+
.build()
68+
val retryInfo = rpc.RetryInfo
69+
.newBuilder()
70+
.setRetryDelay(retryDelayMsg)
71+
.build()
72+
val status = rpc.Status
73+
.newBuilder()
74+
.setMessage(msg)
75+
.setCode(code.value())
76+
.addDetails(Any.pack(retryInfo))
77+
.build()
78+
StatusProto.toStatusRuntimeException(status)
79+
}
80+
81+
/** helper function for comparing two sequences of sleep times */
82+
private def assertLongSequencesAlmostEqual(
83+
first: Seq[Long],
84+
second: Seq[Long],
85+
delta: Long): Unit = {
86+
assert(first.length == second.length, "Lists have different lengths.")
87+
for ((a, b) <- first.zip(second)) {
88+
assert(math.abs(a - b) <= delta, s"Elements $a and $b differ by more than $delta.")
89+
}
90+
}
91+
92+
test("SPARK-44721: Retries run for a minimum period") {
93+
// repeat test few times to avoid random flakes
94+
for (_ <- 1 to 10) {
95+
val st = new SleepTimeTracker()
96+
val dummyFn = new DummyFn(new StatusRuntimeException(Status.UNAVAILABLE), numFails = 100)
97+
val retryHandler = new GrpcRetryHandler(RetryPolicy.defaultPolicies(), st.sleep)
98+
99+
assertThrows[RetriesExceeded] {
100+
retryHandler.retry {
101+
dummyFn.fn()
102+
}
103+
}
104+
105+
assert(st.totalSleep >= 10 * 60 * 1000) // waited at least 10 minutes
106+
}
107+
}
108+
109+
test("SPARK-44275: retry actually retries") {
110+
val dummyFn = new DummyFn(new StatusRuntimeException(Status.UNAVAILABLE))
111+
val retryPolicies = RetryPolicy.defaultPolicies()
112+
val retryHandler = new GrpcRetryHandler(retryPolicies, sleep = _ => {})
113+
val result = retryHandler.retry { dummyFn.fn() }
114+
115+
assert(result == 42)
116+
assert(dummyFn.counter == 3)
117+
}
118+
119+
test("SPARK-44275: default retryException retries only on UNAVAILABLE") {
120+
val dummyFn = new DummyFn(new StatusRuntimeException(Status.ABORTED))
121+
val retryPolicies = RetryPolicy.defaultPolicies()
122+
val retryHandler = new GrpcRetryHandler(retryPolicies, sleep = _ => {})
123+
124+
assertThrows[StatusRuntimeException] {
125+
retryHandler.retry { dummyFn.fn() }
126+
}
127+
assert(dummyFn.counter == 1)
128+
}
129+
130+
test("SPARK-44275: retry uses canRetry to filter exceptions") {
131+
val dummyFn = new DummyFn(new StatusRuntimeException(Status.UNAVAILABLE))
132+
val retryPolicy = RetryPolicy(canRetry = _ => false, name = "TestPolicy")
133+
val retryHandler = new GrpcRetryHandler(retryPolicy)
134+
135+
assertThrows[StatusRuntimeException] {
136+
retryHandler.retry { dummyFn.fn() }
137+
}
138+
assert(dummyFn.counter == 1)
139+
}
140+
141+
test("SPARK-44275: retry does not exceed maxRetries") {
142+
val dummyFn = new DummyFn(new StatusRuntimeException(Status.UNAVAILABLE))
143+
val retryPolicy = RetryPolicy(canRetry = _ => true, maxRetries = Some(1), name = "TestPolicy")
144+
val retryHandler = new GrpcRetryHandler(retryPolicy, sleep = _ => {})
145+
146+
assertThrows[RetriesExceeded] {
147+
retryHandler.retry { dummyFn.fn() }
148+
}
149+
assert(dummyFn.counter == 2)
150+
}
151+
152+
def testPolicySpecificError(maxRetries: Int, status: Status): RetryPolicy = {
153+
RetryPolicy(
154+
maxRetries = Some(maxRetries),
155+
name = s"Policy for ${status.getCode}",
156+
canRetry = {
157+
case e: StatusRuntimeException => e.getStatus.getCode == status.getCode
158+
case _ => false
159+
})
160+
}
161+
162+
test("Test multiple policies") {
163+
val policy1 = testPolicySpecificError(maxRetries = 2, status = Status.UNAVAILABLE)
164+
val policy2 = testPolicySpecificError(maxRetries = 4, status = Status.INTERNAL)
165+
166+
// Tolerate 2 UNAVAILABLE errors and 4 INTERNAL errors
167+
168+
val errors = (List.fill(2)(Status.UNAVAILABLE) ++ List.fill(4)(Status.INTERNAL)).iterator
169+
170+
new GrpcRetryHandler(List(policy1, policy2), sleep = _ => {}).retry({
171+
val e = errors.nextOption()
172+
if (e.isDefined) {
173+
throw e.get.asRuntimeException()
174+
}
175+
})
176+
177+
assert(!errors.hasNext)
178+
}
179+
180+
test("Test multiple policies exceed") {
181+
val policy1 = testPolicySpecificError(maxRetries = 2, status = Status.INTERNAL)
182+
val policy2 = testPolicySpecificError(maxRetries = 4, status = Status.INTERNAL)
183+
184+
val errors = List.fill(10)(Status.INTERNAL).iterator
185+
var countAttempted = 0
186+
187+
assertThrows[RetriesExceeded](
188+
new GrpcRetryHandler(List(policy1, policy2), sleep = _ => {}).retry({
189+
countAttempted += 1
190+
val e = errors.nextOption()
191+
if (e.isDefined) {
192+
throw e.get.asRuntimeException()
193+
}
194+
}))
195+
196+
assert(countAttempted == 3)
197+
}
198+
test("DefaultPolicy retries exceptions with RetryInfo") {
199+
// Error contains RetryInfo with retry_delay set to 0
200+
val dummyFn =
201+
new DummyFn(createTestExceptionWithDetails(msg = "Some error message"), numFails = 100)
202+
val retryPolicies = RetryPolicy.defaultPolicies()
203+
val retryHandler = new GrpcRetryHandler(retryPolicies, sleep = _ => {})
204+
assertThrows[RetriesExceeded] {
205+
retryHandler.retry { dummyFn.fn() }
206+
}
207+
208+
// Should be retried by DefaultPolicy
209+
val policy = retryPolicies.find(_.name == "DefaultPolicy").get
210+
assert(dummyFn.counter == policy.maxRetries.get + 1)
211+
}
212+
213+
test("retry_delay overrides maxBackoff") {
214+
val st = new SleepTimeTracker()
215+
val retryDelay = FiniteDuration(5, "min")
216+
val dummyFn = new DummyFn(
217+
createTestExceptionWithDetails(msg = "Some error message", retryDelay = retryDelay),
218+
numFails = 100)
219+
val retryPolicies = RetryPolicy.defaultPolicies()
220+
val retryHandler = new GrpcRetryHandler(retryPolicies, sleep = st.sleep)
221+
222+
assertThrows[RetriesExceeded] {
223+
retryHandler.retry { dummyFn.fn() }
224+
}
225+
226+
// Should be retried by DefaultPolicy
227+
val policy = retryPolicies.find(_.name == "DefaultPolicy").get
228+
// sleep times are higher than maxBackoff and are equal to retryDelay + jitter
229+
st.times.foreach(t => assert(t > policy.maxBackoff.get.toMillis + policy.jitter.toMillis))
230+
val expectedSleeps = List.fill(policy.maxRetries.get)(retryDelay.toMillis)
231+
assertLongSequencesAlmostEqual(st.times, expectedSleeps, policy.jitter.toMillis)
232+
}
233+
234+
test("maxServerRetryDelay limits retry_delay") {
235+
val st = new SleepTimeTracker()
236+
val retryDelay = FiniteDuration(5, "d")
237+
val dummyFn = new DummyFn(
238+
createTestExceptionWithDetails(msg = "Some error message", retryDelay = retryDelay),
239+
numFails = 100)
240+
val retryPolicies = RetryPolicy.defaultPolicies()
241+
val retryHandler = new GrpcRetryHandler(retryPolicies, sleep = st.sleep)
242+
243+
assertThrows[RetriesExceeded] {
244+
retryHandler.retry { dummyFn.fn() }
245+
}
246+
247+
// Should be retried by DefaultPolicy
248+
val policy = retryPolicies.find(_.name == "DefaultPolicy").get
249+
val expectedSleeps = List.fill(policy.maxRetries.get)(policy.maxServerRetryDelay.get.toMillis)
250+
assertLongSequencesAlmostEqual(st.times, expectedSleeps, policy.jitter.toMillis)
251+
}
252+
253+
test("Policy uses to exponential backoff after retry_delay is unset") {
254+
val st = new SleepTimeTracker()
255+
val retryDelay = FiniteDuration(5, "min")
256+
val retryPolicies = RetryPolicy.defaultPolicies()
257+
val retryHandler = new GrpcRetryHandler(retryPolicies, sleep = st.sleep)
258+
val errors = (
259+
List.fill(2)(
260+
createTestExceptionWithDetails(
261+
msg = "Some error message",
262+
retryDelay = retryDelay)) ++ List.fill(3)(
263+
createTestExceptionWithDetails(
264+
msg = "Some error message",
265+
code = Status.Code.UNAVAILABLE))
266+
).iterator
267+
268+
retryHandler.retry({
269+
if (errors.hasNext) {
270+
throw errors.next()
271+
}
272+
})
273+
assert(!errors.hasNext)
274+
275+
// Should be retried by DefaultPolicy
276+
val policy = retryPolicies.find(_.name == "DefaultPolicy").get
277+
val expectedSleeps = List.fill(2)(retryDelay.toMillis) ++ List.tabulate(3)(i =>
278+
policy.initialBackoff.toMillis * math.pow(policy.backoffMultiplier, i + 2).toLong)
279+
assertLongSequencesAlmostEqual(st.times, expectedSleeps, delta = policy.jitter.toMillis)
280+
}
281+
}

0 commit comments

Comments
 (0)