/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

use aws_smithy_runtime_api::client::interceptors::InterceptorContext;
use aws_smithy_runtime_api::client::orchestrator::{BoxError, ConfigBagAccessors};
use aws_smithy_runtime_api::client::request_attempts::RequestAttempts;
use aws_smithy_runtime_api::client::retries::{
    ClassifyRetry, RetryReason, RetryStrategy, ShouldAttempt,
};
use aws_smithy_types::config_bag::ConfigBag;
use aws_smithy_types::retry::RetryConfig;
use std::time::Duration;

const DEFAULT_MAX_ATTEMPTS: usize = 4;

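/// The standard retry strategy: exponential backoff with jitter.
///
/// On a retryable error, the delay before request attempt `n + 1` is
/// `base * initial_backoff * 2^(n - 1)`, capped at `max_backoff`, where `base`
/// is a random multiplier in `[0, 1)` by default.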
#[derive(Debug)]
pub struct StandardRetryStrategy {
    // The maximum number of attempts to make, including the initial request
    max_attempts: usize,
    // The backoff to use for the first retry
    initial_backoff: Duration,
    // The cap applied to any single backoff delay
    max_backoff: Duration,
    // Produces the jitter multiplier; random by default, overridable for deterministic tests
    base: fn() -> f64,
}

impl StandardRetryStrategy {
    pub fn new(retry_config: &RetryConfig) -> Self {
        // TODO(enableNewSmithyRuntime) add support for `retry_config.reconnect_mode()` here or in the orchestrator flow.
        Self::default()
            .with_max_attempts(retry_config.max_attempts() as usize)
            .with_initial_backoff(retry_config.initial_backoff())
    }

    pub fn with_base(mut self, base: fn() -> f64) -> Self {
        self.base = base;
        self
    }

    pub fn with_max_attempts(mut self, max_attempts: usize) -> Self {
        self.max_attempts = max_attempts;
        self
    }

    pub fn with_initial_backoff(mut self, initial_backoff: Duration) -> Self {
        self.initial_backoff = initial_backoff;
        self
    }
}
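
// A minimal usage sketch: fixing `base` makes the backoff deterministic, which is
// useful in tests. (`RetryConfig::standard()` here comes from `aws_smithy_types::retry`.)
//
//     let strategy = StandardRetryStrategy::new(&RetryConfig::standard())
//         .with_max_attempts(3)
//         .with_base(|| 1.0);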

impl Default for StandardRetryStrategy {
    fn default() -> Self {
        Self {
            max_attempts: DEFAULT_MAX_ATTEMPTS,
            max_backoff: Duration::from_secs(20),
            // By default, use a random base for exponential backoff
            base: fastrand::f64,
            initial_backoff: Duration::from_secs(1),
        }
    }
}

impl RetryStrategy for StandardRetryStrategy {
    // TODO(token-bucket) add support for optional cross-request token bucket
    fn should_attempt_initial_request(&self, _cfg: &ConfigBag) -> Result<ShouldAttempt, BoxError> {
        Ok(ShouldAttempt::Yes)
    }

    fn should_attempt_retry(
        &self,
        ctx: &InterceptorContext,
        cfg: &ConfigBag,
    ) -> Result<ShouldAttempt, BoxError> {
        // Look at the result. If it's `Ok`, then we're done and no retry is required.
        // Otherwise, we need to inspect it.
        let output_or_error = ctx.output_or_error().expect(
            "This must never be called without reaching the point where the result exists.",
        );
        if output_or_error.is_ok() {
            tracing::debug!("request succeeded, no retry necessary");
            return Ok(ShouldAttempt::No);
        }

        // Check if we're out of attempts
        let request_attempts: &RequestAttempts = cfg
            .get()
            .expect("at least one request attempt is made before any retry is attempted");
        if request_attempts.attempts() >= self.max_attempts {
            tracing::trace!(
                attempts = request_attempts.attempts(),
                max_attempts = self.max_attempts,
                "not retrying because we are out of attempts"
            );
            return Ok(ShouldAttempt::No);
        }

        // Run the classifiers against the context to determine if we should retry
        let retry_classifiers = cfg.retry_classifiers();
        let retry_reason = retry_classifiers.classify_retry(ctx);
        let backoff = match retry_reason {
            Some(RetryReason::Explicit(dur)) => dur,
            Some(RetryReason::Error(_)) => {
                let backoff = calculate_exponential_backoff(
                    // Generate a random base multiplier to create jitter
                    (self.base)(),
                    // Get the backoff time multiplier in seconds (with fractional seconds)
                    self.initial_backoff.as_secs_f64(),
                    // `request_attempts.attempts()` tracks the number of requests made,
                    // including the initial request. The initial attempt shouldn't count
                    // towards backoff calculations, so we subtract it.
                    (request_attempts.attempts() - 1) as u32,
                );
                Duration::from_secs_f64(backoff).min(self.max_backoff)
            }
            Some(_) => {
                unreachable!("`RetryReason` is non-exhaustive, so this arm is required even though no other variants currently exist.")
            }
            None => {
                tracing::trace!(
                    attempts = request_attempts.attempts(),
                    max_attempts = self.max_attempts,
                    "encountered unretryable error"
                );
                return Ok(ShouldAttempt::No);
            }
        };

        tracing::debug!(
            "attempt {} failed with {:?}; retrying after {:?}",
            request_attempts.attempts(),
            retry_reason.expect("the match statement above ensures this is not None"),
            backoff
        );

        Ok(ShouldAttempt::YesAfterDelay(backoff))
    }
}

/// Calculates an exponential backoff in seconds: `base * initial_backoff * 2^retry_attempts`.
/// `base` is expected to be a jitter multiplier in `[0, 1)`; the caller caps the result at
/// `max_backoff`.
fn calculate_exponential_backoff(base: f64, initial_backoff: f64, retry_attempts: u32) -> f64 {
    // A floating-point power avoids the integer overflow that `2_u32.pow`
    // would hit once `retry_attempts` reaches 32.
    base * initial_backoff * 2_f64.powi(retry_attempts as i32)
}
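
// For example, with `base` fixed at 1.0 and a 1s initial backoff, successive
// retries are delayed by 1s, 2s, 4s, 8s, and so on, until the `max_backoff`
// cap in `should_attempt_retry` kicks in.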

#[cfg(test)]
mod tests {
    use super::{ShouldAttempt, StandardRetryStrategy};
    use aws_smithy_runtime_api::client::interceptors::InterceptorContext;
    use aws_smithy_runtime_api::client::orchestrator::{ConfigBagAccessors, OrchestratorError};
    use aws_smithy_runtime_api::client::request_attempts::RequestAttempts;
    use aws_smithy_runtime_api::client::retries::{AlwaysRetry, RetryClassifiers, RetryStrategy};
    use aws_smithy_types::config_bag::ConfigBag;
    use aws_smithy_types::retry::ErrorKind;
    use aws_smithy_types::type_erasure::TypeErasedBox;
    use std::time::Duration;

    #[test]
    fn no_retry_necessary_for_ok_result() {
        let cfg = ConfigBag::base();
        let mut ctx = InterceptorContext::new(TypeErasedBox::doesnt_matter());
        let strategy = StandardRetryStrategy::default();
        ctx.set_output_or_error(Ok(TypeErasedBox::doesnt_matter()));
        let actual = strategy
            .should_attempt_retry(&ctx, &cfg)
            .expect("method is infallible for this use");
        assert_eq!(ShouldAttempt::No, actual);
    }

    fn set_up_cfg_and_context(
        error_kind: ErrorKind,
        current_request_attempts: usize,
    ) -> (InterceptorContext, ConfigBag) {
        let mut ctx = InterceptorContext::new(TypeErasedBox::doesnt_matter());
        ctx.set_output_or_error(Err(OrchestratorError::other("doesn't matter")));
        let mut cfg = ConfigBag::base();
        cfg.set_retry_classifiers(RetryClassifiers::new().with_classifier(AlwaysRetry(error_kind)));
        cfg.put(RequestAttempts::new(current_request_attempts));

        (ctx, cfg)
    }

    // Test that each error kind produces the expected "retry after X seconds" decision.
    // All error kinds are handled the same way by the standard strategy: with three
    // attempts already made, a base of 1.0, and a 1s initial backoff, the delay is
    // 1.0 * 1s * 2^(3 - 1) = 4s.
    fn test_should_retry_error_kind(error_kind: ErrorKind) {
        let (ctx, cfg) = set_up_cfg_and_context(error_kind, 3);
        let strategy = StandardRetryStrategy::default().with_base(|| 1.0);
        let actual = strategy
            .should_attempt_retry(&ctx, &cfg)
            .expect("method is infallible for this use");
        assert_eq!(ShouldAttempt::YesAfterDelay(Duration::from_secs(4)), actual);
    }

    #[test]
    fn should_retry_transient_error_result_after_4s() {
        test_should_retry_error_kind(ErrorKind::TransientError);
    }

    #[test]
    fn should_retry_client_error_result_after_4s() {
        test_should_retry_error_kind(ErrorKind::ClientError);
    }

    #[test]
    fn should_retry_server_error_result_after_4s() {
        test_should_retry_error_kind(ErrorKind::ServerError);
    }

    #[test]
    fn should_retry_throttling_error_result_after_4s() {
        test_should_retry_error_kind(ErrorKind::ThrottlingError);
    }

    #[test]
    fn dont_retry_when_out_of_attempts() {
        let current_attempts = 4;
        let max_attempts = current_attempts;
        let (ctx, cfg) = set_up_cfg_and_context(ErrorKind::TransientError, current_attempts);
        let strategy = StandardRetryStrategy::default()
            .with_base(|| 1.0)
            .with_max_attempts(max_attempts);
        let actual = strategy
            .should_attempt_retry(&ctx, &cfg)
            .expect("method is infallible for this use");
        assert_eq!(ShouldAttempt::No, actual);
    }
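
    // A sanity-check sketch of the backoff arithmetic itself: with the base fixed
    // at 1.0, `calculate_exponential_backoff` should double the delay per retry.
    #[test]
    fn backoff_doubles_per_retry_attempt() {
        for (retry_attempts, expected_secs) in [(0u32, 1.0), (1, 2.0), (2, 4.0), (3, 8.0)] {
            let actual = super::calculate_exponential_backoff(1.0, 1.0, retry_attempts);
            assert_eq!(expected_secs, actual);
        }
    }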
}