|
| 1 | +package fn |
| 2 | + |
| 3 | +import ( |
| 4 | + "context" |
| 5 | + "errors" |
| 6 | + "sync/atomic" |
| 7 | + "testing" |
| 8 | + "time" |
| 9 | + |
| 10 | + "github.com/stretchr/testify/require" |
| 11 | + "pgregory.net/rapid" |
| 12 | +) |
| 13 | + |
| 14 | +// TestRetryFuncNSuccessReturnsImmediately verifies that a successful function |
| 15 | +// returns immediately without any retries. |
| 16 | +func TestRetryFuncNSuccessReturnsImmediately(t *testing.T) { |
| 17 | + t.Parallel() |
| 18 | + |
| 19 | + rapid.Check(t, func(t *rapid.T) { |
| 20 | + // Generate a random retry config with reasonable bounds. |
| 21 | + config := RetryConfig{ |
| 22 | + MaxRetries: rapid.IntRange(1, 10).Draw(t, "maxRetries"), |
| 23 | + InitialBackoff: time.Duration( |
| 24 | + rapid.IntRange(1, 100).Draw( |
| 25 | + t, "initialBackoffMs", |
| 26 | + ), |
| 27 | + ) * time.Millisecond, |
| 28 | + BackoffMultiplier: rapid.Float64Range( |
| 29 | + 1.1, 3.0, |
| 30 | + ).Draw(t, "backoffMultiplier"), |
| 31 | + MaxBackoff: time.Duration( |
| 32 | + rapid.IntRange(100, 5000).Draw( |
| 33 | + t, "maxBackoffMs", |
| 34 | + ), |
| 35 | + ) * time.Millisecond, |
| 36 | + } |
| 37 | + |
| 38 | + // Generate a random value to return. |
| 39 | + expectedValue := rapid.Int().Draw(t, "expectedValue") |
| 40 | + |
| 41 | + // Track how many times the function is called. |
| 42 | + var callCount atomic.Int32 |
| 43 | + |
| 44 | + ctx := context.Background() |
| 45 | + start := time.Now() |
| 46 | + |
| 47 | + result, err := RetryFuncN(ctx, config, func() (int, error) { |
| 48 | + callCount.Add(1) |
| 49 | + return expectedValue, nil |
| 50 | + }) |
| 51 | + |
| 52 | + elapsed := time.Since(start) |
| 53 | + |
| 54 | + // The function should only be called once. |
| 55 | + require.Equal(t, int32(1), callCount.Load()) |
| 56 | + |
| 57 | + // No error should be returned. |
| 58 | + require.NoError(t, err) |
| 59 | + |
| 60 | + // The correct value should be returned. |
| 61 | + require.Equal(t, expectedValue, result) |
| 62 | + |
| 63 | + // The function should return almost immediately (allowing for |
| 64 | + // some execution overhead). |
| 65 | + require.Less(t, elapsed, 10*time.Millisecond) |
| 66 | + }) |
| 67 | +} |
| 68 | + |
| 69 | +// TestRetryFuncNRetriesExactlyMaxRetries verifies that a function that always |
| 70 | +// fails is retried exactly MaxRetries times. |
| 71 | +func TestRetryFuncNRetriesExactlyMaxRetries(t *testing.T) { |
| 72 | + t.Parallel() |
| 73 | + |
| 74 | + rapid.Check(t, func(t *rapid.T) { |
| 75 | + // Generate a random retry config. |
| 76 | + maxRetries := rapid.IntRange(0, 5).Draw(t, "maxRetries") |
| 77 | + config := RetryConfig{ |
| 78 | + MaxRetries: maxRetries, |
| 79 | + InitialBackoff: time.Duration( |
| 80 | + rapid.IntRange(1, 10).Draw( |
| 81 | + t, "initialBackoffMs", |
| 82 | + ), |
| 83 | + ) * time.Millisecond, |
| 84 | + BackoffMultiplier: rapid.Float64Range( |
| 85 | + 1.1, 2.0, |
| 86 | + ).Draw(t, "backoffMultiplier"), |
| 87 | + MaxBackoff: time.Duration( |
| 88 | + rapid.IntRange(50, 100).Draw(t, "maxBackoffMs"), |
| 89 | + ) * time.Millisecond, |
| 90 | + } |
| 91 | + |
| 92 | + // Track how many times the function is called. |
| 93 | + var callCount atomic.Int32 |
| 94 | + |
| 95 | + // Create a consistent error for all attempts. |
| 96 | + expectedErr := errors.New("persistent failure") |
| 97 | + |
| 98 | + ctx := context.Background() |
| 99 | + |
| 100 | + _, err := RetryFuncN(ctx, config, func() (int, error) { |
| 101 | + callCount.Add(1) |
| 102 | + return 0, expectedErr |
| 103 | + }) |
| 104 | + |
| 105 | + // The function should be called exactly MaxRetries + 1 times |
| 106 | + // (initial attempt + retries). |
| 107 | + require.Equal(t, int32(maxRetries+1), callCount.Load()) |
| 108 | + |
| 109 | + // The final error should be returned. |
| 110 | + require.Equal(t, expectedErr, err) |
| 111 | + }) |
| 112 | +} |
| 113 | + |
| 114 | +// TestRetryFuncNBackoffIncreases verifies that the backoff duration increases |
| 115 | +// exponentially between retries. |
| 116 | +func TestRetryFuncNBackoffIncreases(t *testing.T) { |
| 117 | + t.Parallel() |
| 118 | + |
| 119 | + rapid.Check(t, func(t *rapid.T) { |
| 120 | + // Generate retry config with at least 2 retries to observe |
| 121 | + // backoff behavior. |
| 122 | + config := RetryConfig{ |
| 123 | + MaxRetries: rapid.IntRange(2, 3).Draw( |
| 124 | + t, "maxRetries", |
| 125 | + ), |
| 126 | + InitialBackoff: time.Duration( |
| 127 | + // We use a slightly larger initial backoff |
| 128 | + // range here to avoid flakes on slow CI |
| 129 | + // machines where scheduling overhead can be |
| 130 | + // significant for very short sleep durations. |
| 131 | + // Reduced range for faster test execution. |
| 132 | + rapid.IntRange(10, 25).Draw( |
| 133 | + t, "initialBackoffMs", |
| 134 | + ), |
| 135 | + ) * time.Millisecond, |
| 136 | + BackoffMultiplier: rapid.Float64Range( |
| 137 | + 1.5, 2.5, |
| 138 | + ).Draw(t, "backoffMultiplier"), |
| 139 | + MaxBackoff: time.Duration( |
| 140 | + rapid.IntRange(50, 150).Draw( |
| 141 | + t, "maxBackoffMs", |
| 142 | + ), |
| 143 | + ) * time.Millisecond, |
| 144 | + } |
| 145 | + |
| 146 | + // Track call times to measure backoff. |
| 147 | + var callTimes []time.Time |
| 148 | + |
| 149 | + ctx := context.Background() |
| 150 | + |
| 151 | + _, err := RetryFuncN(ctx, config, func() (int, error) { |
| 152 | + callTimes = append(callTimes, time.Now()) |
| 153 | + return 0, errors.New("fail") |
| 154 | + }) |
| 155 | + |
| 156 | + require.Error(t, err) |
| 157 | + require.Len(t, callTimes, config.MaxRetries+1) |
| 158 | + |
| 159 | + expectedBackoff := config.InitialBackoff |
| 160 | + for i := 1; i < len(callTimes); i++ { |
| 161 | + actualBackoff := callTimes[i].Sub(callTimes[i-1]) |
| 162 | + |
| 163 | + // To avoid flakes, we use an asymmetric check for the |
| 164 | + // backoff duration. We expect the actual backoff to be |
| 165 | + // close to the expected backoff, but we allow for a |
| 166 | + // generous upper bound to account for scheduling delays |
| 167 | + // on busy systems. // The actual backoff should be |
| 168 | + // reasonably close to the expected backoff. We allow it |
| 169 | + // to be slightly shorter due to timer precision. |
| 170 | + lowerBound := time.Duration( |
| 171 | + float64(expectedBackoff) * 0.8, |
| 172 | + ) |
| 173 | + require.GreaterOrEqual(t, actualBackoff, lowerBound) |
| 174 | + |
| 175 | + // The actual backoff can be longer due to scheduling |
| 176 | + // delays. We allow a generous upper bound. |
| 177 | + upperBound := time.Duration( |
| 178 | + float64(expectedBackoff)*1.5, |
| 179 | + ) + 100*time.Millisecond |
| 180 | + require.LessOrEqual(t, actualBackoff, upperBound) |
| 181 | + |
| 182 | + // Calculate the next expected backoff, capping at |
| 183 | + // MaxBackoff. |
| 184 | + expectedBackoff = time.Duration( |
| 185 | + float64(expectedBackoff) * |
| 186 | + config.BackoffMultiplier, |
| 187 | + ) |
| 188 | + if expectedBackoff > config.MaxBackoff { |
| 189 | + expectedBackoff = config.MaxBackoff |
| 190 | + } |
| 191 | + } |
| 192 | + }) |
| 193 | +} |
| 194 | + |
| 195 | +// TestRetryFuncNContextCancellation verifies that context cancellation stops |
| 196 | +// the retry loop immediately. |
| 197 | +func TestRetryFuncNContextCancellation(t *testing.T) { |
| 198 | + t.Parallel() |
| 199 | + |
| 200 | + rapid.Check(t, func(t *rapid.T) { |
| 201 | + // Generate a retry config with shorter timeouts for faster test |
| 202 | + // execution. |
| 203 | + config := RetryConfig{ |
| 204 | + MaxRetries: rapid.IntRange(2, 5).Draw(t, "maxRetries"), |
| 205 | + InitialBackoff: time.Duration( |
| 206 | + rapid.IntRange(10, 50).Draw( |
| 207 | + t, "initialBackoffMs", |
| 208 | + ), |
| 209 | + ) * time.Millisecond, |
| 210 | + BackoffMultiplier: 1.5, |
| 211 | + MaxBackoff: 100 * time.Millisecond, |
| 212 | + } |
| 213 | + |
| 214 | + // Track how many times the function is called. |
| 215 | + var callCount atomic.Int32 |
| 216 | + |
| 217 | + // Cancel the context after the first attempt to ensure we |
| 218 | + // cancel during a backoff wait. |
| 219 | + ctx, cancel := context.WithCancel(context.Background()) |
| 220 | + |
| 221 | + // Schedule cancellation after a short delay. |
| 222 | + go func() { |
| 223 | + time.Sleep(5 * time.Millisecond) |
| 224 | + cancel() |
| 225 | + }() |
| 226 | + |
| 227 | + _, err := RetryFuncN(ctx, config, func() (int, error) { |
| 228 | + callCount.Add(1) |
| 229 | + return 0, errors.New("fail") |
| 230 | + }) |
| 231 | + |
| 232 | + // The error should be the context cancellation error. |
| 233 | + require.Equal(t, context.Canceled, err) |
| 234 | + |
| 235 | + // The function should have been called at least once but not |
| 236 | + // more than MaxRetries+1 times. |
| 237 | + calls := callCount.Load() |
| 238 | + require.GreaterOrEqual(t, calls, int32(1)) |
| 239 | + require.LessOrEqual(t, calls, int32(config.MaxRetries+1)) |
| 240 | + }) |
| 241 | +} |
| 242 | + |
| 243 | +// TestRetryFuncNEventualSuccess verifies that if a function succeeds after some |
| 244 | +// failures, the correct result is returned. |
| 245 | +func TestRetryFuncNEventualSuccess(t *testing.T) { |
| 246 | + t.Parallel() |
| 247 | + |
| 248 | + rapid.Check(t, func(t *rapid.T) { |
| 249 | + config := RetryConfig{ |
| 250 | + MaxRetries: rapid.IntRange(3, 10).Draw(t, "maxRetries"), |
| 251 | + InitialBackoff: time.Duration( |
| 252 | + rapid.IntRange(1, 10).Draw( |
| 253 | + t, "initialBackoffMs", |
| 254 | + ), |
| 255 | + ) * time.Millisecond, |
| 256 | + BackoffMultiplier: 2.0, |
| 257 | + MaxBackoff: 50 * time.Millisecond, |
| 258 | + } |
| 259 | + |
| 260 | + // Determine after how many attempts the function should |
| 261 | + // succeed. |
| 262 | + succeedAfter := rapid.IntRange( |
| 263 | + 1, config.MaxRetries+1, |
| 264 | + ).Draw(t, "succeedAfter") |
| 265 | + |
| 266 | + expectedValue := rapid.Int().Draw(t, "expectedValue") |
| 267 | + |
| 268 | + // Track how many times the function is called. |
| 269 | + var callCount atomic.Int32 |
| 270 | + |
| 271 | + ctx := context.Background() |
| 272 | + |
| 273 | + result, err := RetryFuncN(ctx, config, func() (int, error) { |
| 274 | + count := callCount.Add(1) |
| 275 | + if int(count) >= succeedAfter { |
| 276 | + return expectedValue, nil |
| 277 | + } |
| 278 | + return 0, errors.New("temporary failure") |
| 279 | + }) |
| 280 | + |
| 281 | + // The function should succeed. |
| 282 | + require.NoError(t, err) |
| 283 | + require.Equal(t, expectedValue, result) |
| 284 | + |
| 285 | + // The function should be called exactly succeedAfter times. |
| 286 | + require.Equal(t, int32(succeedAfter), callCount.Load()) |
| 287 | + }) |
| 288 | +} |
0 commit comments