From cd258433c64420b6b24336dce242667c63a164e4 Mon Sep 17 00:00:00 2001
From: Michele Cardone
Date: Fri, 18 Apr 2025 12:45:32 +0200
Subject: [PATCH] feat: add `ExponentialJitterBackoff` backoff strategy

The new strategy is an extension of the default one that applies a
jitter to avoid thundering herd.
---
 client.go      |  30 ++++++++++
 client_test.go | 152 ++++++++++++++++++++++++++++++++++++++++++++++++-
 2 files changed, 181 insertions(+), 1 deletion(-)

diff --git a/client.go b/client.go
index efee53c..fc52a80 100644
--- a/client.go
+++ b/client.go
@@ -638,6 +638,36 @@ func LinearJitterBackoff(min, max time.Duration, attemptNum int, resp *http.Resp
 	return time.Duration(jitterMin * int64(attemptNum))
 }
 
+// ExponentialJitterBackoff is an extension of DefaultBackoff that applies
+// a jitter of up to +/-25% (clamped to [min, max]) to avoid thundering herd.
+func ExponentialJitterBackoff(min, max time.Duration, attemptNum int, resp *http.Response) time.Duration {
+	baseBackoff := DefaultBackoff(min, max, attemptNum, resp)
+
+	if resp != nil {
+		if retryAfterHeaders := resp.Header["Retry-After"]; len(retryAfterHeaders) > 0 && retryAfterHeaders[0] != "" {
+			return baseBackoff // non-empty Retry-After: return DefaultBackoff's result without jitter
+		}
+	}
+
+	// Seed randomization; it's OK to do it every time
+	rnd := rand.New(rand.NewSource(time.Now().UnixNano()))
+
+	jitter := rnd.Float64()*0.5 - 0.25 // Random value in [-0.25, +0.25)
+	jitteredSleep := time.Duration(float64(baseBackoff) * (1.0 + jitter))
+
+	return clampDuration(jitteredSleep, min, max)
+}
+
+func clampDuration(d, min, max time.Duration) time.Duration {
+	if d < min {
+		return min // d is below the lower bound
+	}
+	if d > max {
+		return max // d is above the upper bound
+	}
+	return d
+}
+
 // PassthroughErrorHandler is an ErrorHandler that directly passes through the
 // values from the net/http library for the final request. The body is not
 // closed.
diff --git a/client_test.go b/client_test.go
index d12cd1b..82d3ae4 100644
--- a/client_test.go
+++ b/client_test.go
@@ -251,7 +251,7 @@ func testClientDo(t *testing.T, body interface{}) {
 	}
 
 	if resp.StatusCode != 200 {
-		t.Fatalf("exected 200, got: %d", resp.StatusCode)
+		t.Fatalf("expected 200, got: %d", resp.StatusCode)
 	}
 
 	if retryCount < 0 {
@@ -896,6 +896,156 @@ func TestClient_DefaultBackoff(t *testing.T) {
 	}
 }
 
+func TestClient_ExponentialJitterBackoff(t *testing.T) {
+	const retriableStatusCode int = http.StatusServiceUnavailable
+
+	t.Run("with non-empty first value of Retry-After header in response", func(t *testing.T) {
+		response := &http.Response{
+			StatusCode: retriableStatusCode,
+			Header: http.Header{
+				"Content-Type": []string{"application/json"},
+				"Retry-After":  []string{"42"},
+			},
+		}
+		backoff := ExponentialJitterBackoff(retryWaitMin, retryWaitMax, 3, response)
+		expectedBackoff := 42 * time.Second // Retry-After value, expected back without jitter
+
+		if backoff != expectedBackoff {
+			t.Fatalf("expected default backoff from Retry-After header (%s), got %s", expectedBackoff, backoff)
+		}
+	})
+
+	invalidRetryAfterHeaderCases := []struct {
+		name         string
+		makeResponse func() *http.Response
+	}{
+		{
+			name: "with empty first value of Retry-After header in response",
+			makeResponse: func() *http.Response {
+				return &http.Response{
+					StatusCode: retriableStatusCode,
+					Header: http.Header{
+						"Content-Type": []string{"application/json"},
+						"Retry-After":  []string{""},
+					},
+				}
+			},
+		},
+		{
+			name: "without Retry-After header in response",
+			makeResponse: func() *http.Response {
+				return &http.Response{
+					StatusCode: retriableStatusCode,
+					Header:     http.Header{"Content-Type": []string{"application/json"}},
+				}
+			},
+		},
+		{
+			name: "with nil response",
+			makeResponse: func() *http.Response {
+				return nil
+			},
+		},
+	}
+
+	for _, irahc := range invalidRetryAfterHeaderCases {
+		t.Run(irahc.name, func(t *testing.T) {
+			attemptNumCases := []struct {
+				name                         string
+				attemptNum                   int
+				expectedBackoffWithoutJitter time.Duration
+			}{
+				{
+					name:                         "with first attempt",
+					attemptNum:                   0,
+					expectedBackoffWithoutJitter: retryWaitMin,
+				},
+				{
+					name:                         "with low attempt number",
+					attemptNum:                   3,
+					expectedBackoffWithoutJitter: 16 * time.Second,
+				},
+				{
+					name:                         "with high attempt number",
+					attemptNum:                   10,
+					expectedBackoffWithoutJitter: retryWaitMax,
+				},
+			}
+
+			for _, anc := range attemptNumCases {
+				t.Run(anc.name, func(t *testing.T) {
+					backoff := ExponentialJitterBackoff(retryWaitMin, retryWaitMax, anc.attemptNum, irahc.makeResponse()) // same bounds the expectations below are computed from
+					expectedJitterDelta := float64(anc.expectedBackoffWithoutJitter) * 0.25 // jitter spans +/-25% of the base backoff
+					expectedMinTime := anc.expectedBackoffWithoutJitter - time.Duration(expectedJitterDelta)
+					expectedMaxTime := anc.expectedBackoffWithoutJitter + time.Duration(expectedJitterDelta)
+					expectedBackoffLowerLimit := max(expectedMinTime, retryWaitMin) // result is clamped to the min wait
+					expectedBackoffUpperLimit := min(expectedMaxTime, retryWaitMax) // result is clamped to the max wait
+
+					t.Run("returns exponential backoff with jitter, clamped within min and max limits", func(t *testing.T) {
+						if backoff < expectedBackoffLowerLimit || backoff > expectedBackoffUpperLimit {
+							t.Fatalf("expected backoff to be within range [%s, %s], got %s", expectedBackoffLowerLimit, expectedBackoffUpperLimit, backoff)
+						}
+					})
+				})
+			}
+		})
+	}
+}
+
+func Test_clampDuration(t *testing.T) {
+	const (
+		minDuration time.Duration = 500 * time.Millisecond
+		maxDuration time.Duration = 10 * time.Minute
+	)
+
+	testCases := []struct {
+		name                    string
+		errorMessage            string
+		duration                time.Duration
+		expectedClampedDuration time.Duration
+	}{
+		{
+			name:                    "with duration below min value",
+			errorMessage:            "should return the min value",
+			duration:                60 * time.Microsecond,
+			expectedClampedDuration: minDuration,
+		},
+		{
+			name:                    "with duration equal to min value",
+			errorMessage:            "should return the min value",
+			duration:                minDuration,
+			expectedClampedDuration: minDuration,
+		},
+		{
+			name:                    "with duration strictly within min and max range",
+			errorMessage:            "should return the given value",
+			duration:                45 * time.Second,
+			expectedClampedDuration: 45 * time.Second,
+		},
+		{
+			name:                    "with duration equal to max value",
+			errorMessage:            "should return the max value",
+			duration:                maxDuration,
+			expectedClampedDuration: maxDuration,
+		},
+		{
+			name:                    "with duration above max value",
+			errorMessage:            "should return the max value",
+			duration:                2 * time.Hour,
+			expectedClampedDuration: maxDuration,
+		},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			duration := clampDuration(tc.duration, minDuration, maxDuration)
+			if duration != tc.expectedClampedDuration {
+				t.Fatalf("%s: expected duration %s, got %s", tc.errorMessage, tc.expectedClampedDuration, duration)
+			}
+		})
+	}
+}
+
 func TestClient_DefaultRetryPolicy_TLS(t *testing.T) {
 	ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		w.WriteHeader(200)