Skip to content

Commit 99e7497

Browse files
committed
scala
1 parent c6ab96a commit 99e7497

File tree

4 files changed

+373
-132
lines changed

4 files changed

+373
-132
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,298 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
package org.apache.spark.sql.connect.client
18+
19+
import scala.concurrent.duration.FiniteDuration
20+
21+
import com.google.protobuf.{Any, Duration}
22+
import com.google.rpc
23+
import io.grpc.{Status, StatusRuntimeException}
24+
import io.grpc.protobuf.StatusProto
25+
import org.scalatest.BeforeAndAfterEach
26+
import org.scalatest.concurrent.Eventually
27+
28+
import org.apache.spark.sql.connect.test.ConnectFunSuite
29+
30+
class SparkConnectClientRetriesSuite
31+
extends ConnectFunSuite
32+
with BeforeAndAfterEach
33+
with Eventually {
34+
35+
private class DummyFn(e: => Throwable, numFails: Int = 3) {
36+
var counter = 0
37+
def fn(): Int = {
38+
if (counter < numFails) {
39+
counter += 1
40+
throw e
41+
} else {
42+
42
43+
}
44+
}
45+
}
46+
47+
/** Tracks sleep times in milliseconds for testing purposes. */
48+
private class SleepTimeTracker {
49+
private val data = scala.collection.mutable.ListBuffer[Long]()
50+
def sleep(t: Long): Unit = data.append(t)
51+
def times: List[Long] = data.toList
52+
def totalSleep: Long = data.sum
53+
}
54+
55+
/** Helper function for creating a test exception with retry_delay */
56+
private def createTestExceptionWithDetails(
57+
msg: String,
58+
code: Status.Code = Status.Code.INTERNAL,
59+
retryDelay: FiniteDuration = FiniteDuration(0, "s")
60+
): StatusRuntimeException = {
61+
// In grpc-java, RetryDelay should be specified as seconds: Long + nanos: Int
62+
val seconds = retryDelay.toSeconds
63+
val nanos = (retryDelay - FiniteDuration(seconds, "s")).toNanos.toInt
64+
val retryDelayMsg = Duration
65+
.newBuilder()
66+
.setSeconds(seconds)
67+
.setNanos(nanos)
68+
.build()
69+
val retryInfo = rpc.RetryInfo
70+
.newBuilder()
71+
.setRetryDelay(retryDelayMsg)
72+
.build()
73+
val status = rpc.Status
74+
.newBuilder()
75+
.setMessage(msg)
76+
.setCode(code.value())
77+
.addDetails(Any.pack(retryInfo))
78+
.build()
79+
StatusProto.toStatusRuntimeException(status)
80+
}
81+
82+
/** helper function for comparing two sequences of sleep times */
83+
private def assertLongSequencesAlmostEqual(
84+
first: Seq[Long],
85+
second: Seq[Long],
86+
delta: Long
87+
): Unit = {
88+
assert(first.length == second.length, "Lists have different lengths.")
89+
for ((a, b) <- first.zip(second)) {
90+
assert(math.abs(a - b) <= delta, s"Elements $a and $b differ by more than $delta.")
91+
}
92+
}
93+
94+
test("SPARK-44721: Retries run for a minimum period") {
95+
// repeat test few times to avoid random flakes
96+
for (_ <- 1 to 10) {
97+
val st = new SleepTimeTracker()
98+
val dummyFn = new DummyFn(new StatusRuntimeException(Status.UNAVAILABLE), numFails = 100)
99+
val retryHandler = new GrpcRetryHandler(RetryPolicy.defaultPolicies(), st.sleep)
100+
101+
assertThrows[RetriesExceeded] {
102+
retryHandler.retry {
103+
dummyFn.fn()
104+
}
105+
}
106+
107+
assert(st.totalSleep >= 10 * 60 * 1000) // waited at least 10 minutes
108+
}
109+
}
110+
111+
test("SPARK-44275: retry actually retries") {
112+
val dummyFn = new DummyFn(new StatusRuntimeException(Status.UNAVAILABLE))
113+
val retryPolicies = RetryPolicy.defaultPolicies()
114+
val retryHandler = new GrpcRetryHandler(retryPolicies, sleep = _ => {})
115+
val result = retryHandler.retry { dummyFn.fn() }
116+
117+
assert(result == 42)
118+
assert(dummyFn.counter == 3)
119+
}
120+
121+
test("SPARK-44275: default retryException retries only on UNAVAILABLE") {
122+
val dummyFn = new DummyFn(new StatusRuntimeException(Status.ABORTED))
123+
val retryPolicies = RetryPolicy.defaultPolicies()
124+
val retryHandler = new GrpcRetryHandler(retryPolicies, sleep = _ => {})
125+
126+
assertThrows[StatusRuntimeException] {
127+
retryHandler.retry { dummyFn.fn() }
128+
}
129+
assert(dummyFn.counter == 1)
130+
}
131+
132+
test("SPARK-44275: retry uses canRetry to filter exceptions") {
133+
val dummyFn = new DummyFn(new StatusRuntimeException(Status.UNAVAILABLE))
134+
val retryPolicy = RetryPolicy(canRetry = _ => false, name = "TestPolicy")
135+
val retryHandler = new GrpcRetryHandler(retryPolicy)
136+
137+
assertThrows[StatusRuntimeException] {
138+
retryHandler.retry { dummyFn.fn() }
139+
}
140+
assert(dummyFn.counter == 1)
141+
}
142+
143+
test("SPARK-44275: retry does not exceed maxRetries") {
144+
val dummyFn = new DummyFn(new StatusRuntimeException(Status.UNAVAILABLE))
145+
val retryPolicy = RetryPolicy(canRetry = _ => true, maxRetries = Some(1), name = "TestPolicy")
146+
val retryHandler = new GrpcRetryHandler(retryPolicy, sleep = _ => {})
147+
148+
assertThrows[RetriesExceeded] {
149+
retryHandler.retry { dummyFn.fn() }
150+
}
151+
assert(dummyFn.counter == 2)
152+
}
153+
154+
def testPolicySpecificError(maxRetries: Int, status: Status): RetryPolicy = {
155+
RetryPolicy(
156+
maxRetries = Some(maxRetries),
157+
name = s"Policy for ${status.getCode}",
158+
canRetry = {
159+
case e: StatusRuntimeException => e.getStatus.getCode == status.getCode
160+
case _ => false
161+
})
162+
}
163+
164+
test("Test multiple policies") {
165+
val policy1 = testPolicySpecificError(maxRetries = 2, status = Status.UNAVAILABLE)
166+
val policy2 = testPolicySpecificError(maxRetries = 4, status = Status.INTERNAL)
167+
168+
// Tolerate 2 UNAVAILABLE errors and 4 INTERNAL errors
169+
170+
val errors = (List.fill(2)(Status.UNAVAILABLE) ++ List.fill(4)(Status.INTERNAL)).iterator
171+
172+
new GrpcRetryHandler(List(policy1, policy2), sleep = _ => {}).retry({
173+
val e = errors.nextOption()
174+
if (e.isDefined) {
175+
throw e.get.asRuntimeException()
176+
}
177+
})
178+
179+
assert(!errors.hasNext)
180+
}
181+
182+
test("Test multiple policies exceed") {
183+
val policy1 = testPolicySpecificError(maxRetries = 2, status = Status.INTERNAL)
184+
val policy2 = testPolicySpecificError(maxRetries = 4, status = Status.INTERNAL)
185+
186+
val errors = List.fill(10)(Status.INTERNAL).iterator
187+
var countAttempted = 0
188+
189+
assertThrows[RetriesExceeded](
190+
new GrpcRetryHandler(List(policy1, policy2), sleep = _ => {}).retry({
191+
countAttempted += 1
192+
val e = errors.nextOption()
193+
if (e.isDefined) {
194+
throw e.get.asRuntimeException()
195+
}
196+
}))
197+
198+
assert(countAttempted == 7)
199+
}
200+
test("DefaultPolicy retries exceptions with RetryInfo") {
201+
// Error contains RetryInfo with retry_delay set to 0
202+
val dummyFn = new DummyFn(
203+
createTestExceptionWithDetails(msg = "Some error message"),
204+
numFails = 100
205+
)
206+
val retryPolicies = RetryPolicy.defaultPolicies()
207+
val retryHandler = new GrpcRetryHandler(retryPolicies, sleep = _ => {})
208+
assertThrows[RetriesExceeded] {
209+
retryHandler.retry { dummyFn.fn() }
210+
}
211+
212+
// Should be retried by DefaultPolicy
213+
val policy = retryPolicies.find(_.name == "DefaultPolicy").get
214+
assert(dummyFn.counter == policy.maxRetries.get + 1)
215+
}
216+
217+
test("retry_delay overrides maxBackoff") {
218+
val st = new SleepTimeTracker()
219+
val retryDelay = FiniteDuration(5, "min")
220+
val dummyFn = new DummyFn(
221+
createTestExceptionWithDetails(
222+
msg = "Some error message",
223+
retryDelay = retryDelay
224+
),
225+
numFails = 100
226+
)
227+
val retryPolicies = RetryPolicy.defaultPolicies()
228+
val retryHandler = new GrpcRetryHandler(retryPolicies, sleep = st.sleep)
229+
230+
assertThrows[RetriesExceeded] {
231+
retryHandler.retry { dummyFn.fn() }
232+
}
233+
234+
// Should be retried by DefaultPolicy
235+
val policy = retryPolicies.find(_.name == "DefaultPolicy").get
236+
// sleep times are higher than maxBackoff and are equal to retryDelay + jitter
237+
st.times.foreach(t => assert(t > policy.maxBackoff.get.toMillis + policy.jitter.toMillis))
238+
val expectedSleeps = List.fill(policy.maxRetries.get)(retryDelay.toMillis)
239+
assertLongSequencesAlmostEqual(st.times, expectedSleeps, policy.jitter.toMillis)
240+
}
241+
242+
test("maxServerRetryDelay limits retry_delay") {
243+
val st = new SleepTimeTracker()
244+
val retryDelay = FiniteDuration(5, "d")
245+
val dummyFn = new DummyFn(
246+
createTestExceptionWithDetails(
247+
msg = "Some error message",
248+
retryDelay = retryDelay
249+
),
250+
numFails = 100
251+
)
252+
val retryPolicies = RetryPolicy.defaultPolicies()
253+
val retryHandler = new GrpcRetryHandler(retryPolicies, sleep = st.sleep)
254+
255+
assertThrows[RetriesExceeded] {
256+
retryHandler.retry { dummyFn.fn() }
257+
}
258+
259+
// Should be retried by DefaultPolicy
260+
val policy = retryPolicies.find(_.name == "DefaultPolicy").get
261+
val expectedSleeps = List.fill(policy.maxRetries.get)(policy.maxServerRetryDelay.get.toMillis)
262+
assertLongSequencesAlmostEqual(st.times, expectedSleeps, policy.jitter.toMillis)
263+
}
264+
265+
test("Policy uses to exponential backoff after retry_delay is unset") {
266+
val st = new SleepTimeTracker()
267+
val retryDelay = FiniteDuration(5, "min")
268+
val retryPolicies = RetryPolicy.defaultPolicies()
269+
val retryHandler = new GrpcRetryHandler(retryPolicies, sleep = st.sleep)
270+
val errors = (
271+
List.fill(2)(
272+
createTestExceptionWithDetails(
273+
msg = "Some error message",
274+
retryDelay = retryDelay
275+
)
276+
) ++ List.fill(3)(
277+
createTestExceptionWithDetails(
278+
msg = "Some error message",
279+
code = Status.Code.UNAVAILABLE
280+
)
281+
)
282+
).iterator
283+
284+
retryHandler.retry({
285+
if (errors.hasNext) {
286+
throw errors.next()
287+
}
288+
})
289+
assert(!errors.hasNext)
290+
291+
// Should be retried by DefaultPolicy
292+
val policy = retryPolicies.find(_.name == "DefaultPolicy").get
293+
val expectedSleeps = List.fill(2)(retryDelay.toMillis) ++ List.tabulate(3)(
294+
i => policy.initialBackoff.toMillis * math.pow(policy.backoffMultiplier, i + 2).toLong
295+
)
296+
assertLongSequencesAlmostEqual(st.times, expectedSleeps, delta = policy.jitter.toMillis)
297+
}
298+
}

0 commit comments

Comments
 (0)