Skip to content

Commit cfade88

Browse files
authored
Implement the RecursionDetectionInterceptor for the SDK (#2555)
1 parent 093b65a commit cfade88

File tree

9 files changed

+299
-0
lines changed

9 files changed

+299
-0
lines changed

aws/rust-runtime/aws-http/src/recursion_detection.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ use http::HeaderValue;
1111
use percent_encoding::{percent_encode, CONTROLS};
1212
use std::borrow::Cow;
1313

14+
// TODO(enableNewSmithyRuntime): Delete this module
15+
1416
/// Recursion Detection Middleware
1517
///
1618
/// This middleware inspects the value of the `AWS_LAMBDA_FUNCTION_NAME` and `_X_AMZN_TRACE_ID` environment

aws/rust-runtime/aws-http/src/user_agent.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -513,6 +513,8 @@ impl fmt::Display for ExecEnvMetadata {
513513
}
514514
}
515515

516+
// TODO(enableNewSmithyRuntime): Delete the user agent Tower middleware and consider moving all the remaining code into aws-runtime
517+
516518
/// User agent middleware
517519
#[non_exhaustive]
518520
#[derive(Default, Clone, Debug)]

aws/rust-runtime/aws-runtime/Cargo.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,14 @@ aws-smithy-runtime-api = { path = "../../../rust-runtime/aws-smithy-runtime-api"
1616
aws-smithy-types = { path = "../../../rust-runtime/aws-smithy-types" }
1717
aws-types = { path = "../aws-types" }
1818
http = "0.2.3"
19+
percent-encoding = "2.1.0"
1920
tracing = "0.1"
2021

2122
[dev-dependencies]
23+
aws-smithy-protocol-test = { path = "../../../rust-runtime/aws-smithy-protocol-test" }
24+
proptest = "1"
25+
serde = { version = "1", features = ["derive"]}
26+
serde_json = "1"
2227
tracing-test = "0.2.1"
2328

2429
[package.metadata.docs.rs]

aws/rust-runtime/aws-runtime/src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,5 +19,8 @@ pub mod auth;
1919
/// Supporting code for identity in the AWS SDK.
2020
pub mod identity;
2121

22+
/// Supporting code for recursion detection in the AWS SDK.
23+
pub mod recursion_detection;
24+
2225
/// Supporting code for user agent headers in the AWS SDK.
2326
pub mod user_agent;
Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
use aws_smithy_runtime_api::client::interceptors::{BoxError, Interceptor, InterceptorContext};
7+
use aws_smithy_runtime_api::client::orchestrator::{HttpRequest, HttpResponse};
8+
use aws_smithy_runtime_api::config_bag::ConfigBag;
9+
use aws_types::os_shim_internal::Env;
10+
use http::HeaderValue;
11+
use percent_encoding::{percent_encode, CONTROLS};
12+
use std::borrow::Cow;
13+
14+
const TRACE_ID_HEADER: &str = "x-amzn-trace-id";
15+
16+
mod env {
17+
pub(super) const LAMBDA_FUNCTION_NAME: &str = "AWS_LAMBDA_FUNCTION_NAME";
18+
pub(super) const TRACE_ID: &str = "_X_AMZN_TRACE_ID";
19+
}
20+
21+
/// Recursion Detection Interceptor
22+
///
23+
/// This interceptor inspects the value of the `AWS_LAMBDA_FUNCTION_NAME` and `_X_AMZN_TRACE_ID` environment
24+
/// variables to detect if the request is being invoked in a Lambda function. If it is, the `X-Amzn-Trace-Id` header
25+
/// will be set. This enables downstream services to prevent accidentally infinitely recursive invocations spawned
26+
/// from Lambda.
27+
#[non_exhaustive]
28+
#[derive(Debug, Default)]
29+
pub struct RecursionDetectionInterceptor {
30+
env: Env,
31+
}
32+
33+
impl RecursionDetectionInterceptor {
34+
/// Creates a new `RecursionDetectionInterceptor`
35+
pub fn new() -> Self {
36+
Self::default()
37+
}
38+
}
39+
40+
impl Interceptor<HttpRequest, HttpResponse> for RecursionDetectionInterceptor {
41+
fn modify_before_signing(
42+
&self,
43+
context: &mut InterceptorContext<HttpRequest, HttpResponse>,
44+
_cfg: &mut ConfigBag,
45+
) -> Result<(), BoxError> {
46+
let request = context.request_mut()?;
47+
if request.headers().contains_key(TRACE_ID_HEADER) {
48+
return Ok(());
49+
}
50+
51+
if let (Ok(_function_name), Ok(trace_id)) = (
52+
self.env.get(env::LAMBDA_FUNCTION_NAME),
53+
self.env.get(env::TRACE_ID),
54+
) {
55+
request
56+
.headers_mut()
57+
.insert(TRACE_ID_HEADER, encode_header(trace_id.as_bytes()));
58+
}
59+
Ok(())
60+
}
61+
}
62+
63+
/// Encodes a byte slice as a header.
64+
///
65+
/// ASCII control characters are percent encoded which ensures that all byte sequences are valid headers
66+
fn encode_header(value: &[u8]) -> HeaderValue {
67+
let value: Cow<'_, str> = percent_encode(value, CONTROLS).into();
68+
HeaderValue::from_bytes(value.as_bytes()).expect("header is encoded, header must be valid")
69+
}
70+
71+
#[cfg(test)]
72+
mod tests {
73+
use super::*;
74+
use aws_smithy_http::body::SdkBody;
75+
use aws_smithy_protocol_test::{assert_ok, validate_headers};
76+
use aws_smithy_runtime_api::type_erasure::TypedBox;
77+
use aws_types::os_shim_internal::Env;
78+
use http::HeaderValue;
79+
use proptest::{prelude::*, proptest};
80+
use serde::Deserialize;
81+
use std::collections::HashMap;
82+
83+
proptest! {
84+
#[test]
85+
fn header_encoding_never_panics(s in any::<Vec<u8>>()) {
86+
encode_header(&s);
87+
}
88+
}
89+
90+
#[test]
91+
fn every_char() {
92+
let buff = (0..=255).collect::<Vec<u8>>();
93+
assert_eq!(
94+
encode_header(&buff),
95+
HeaderValue::from_static(
96+
r##"%00%01%02%03%04%05%06%07%08%09%0A%0B%0C%0D%0E%0F%10%11%12%13%14%15%16%17%18%19%1A%1B%1C%1D%1E%1F !"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\]^_`abcdefghijklmnopqrstuvwxyz{|}~%7F%80%81%82%83%84%85%86%87%88%89%8A%8B%8C%8D%8E%8F%90%91%92%93%94%95%96%97%98%99%9A%9B%9C%9D%9E%9F%A0%A1%A2%A3%A4%A5%A6%A7%A8%A9%AA%AB%AC%AD%AE%AF%B0%B1%B2%B3%B4%B5%B6%B7%B8%B9%BA%BB%BC%BD%BE%BF%C0%C1%C2%C3%C4%C5%C6%C7%C8%C9%CA%CB%CC%CD%CE%CF%D0%D1%D2%D3%D4%D5%D6%D7%D8%D9%DA%DB%DC%DD%DE%DF%E0%E1%E2%E3%E4%E5%E6%E7%E8%E9%EA%EB%EC%ED%EE%EF%F0%F1%F2%F3%F4%F5%F6%F7%F8%F9%FA%FB%FC%FD%FE%FF"##
97+
)
98+
);
99+
}
100+
101+
#[test]
102+
fn run_tests() {
103+
let test_cases: Vec<TestCase> =
104+
serde_json::from_str(include_str!("../test-data/recursion-detection.json"))
105+
.expect("invalid test case");
106+
for test_case in test_cases {
107+
check(test_case)
108+
}
109+
}
110+
111+
#[derive(Deserialize)]
112+
#[serde(rename_all = "camelCase")]
113+
struct TestCase {
114+
env: HashMap<String, String>,
115+
request_headers_before: Vec<String>,
116+
request_headers_after: Vec<String>,
117+
}
118+
119+
impl TestCase {
120+
fn env(&self) -> Env {
121+
Env::from(self.env.clone())
122+
}
123+
124+
/// Headers on the input request
125+
fn request_headers_before(&self) -> impl Iterator<Item = (&str, &str)> {
126+
Self::split_headers(&self.request_headers_before)
127+
}
128+
129+
/// Headers on the output request
130+
fn request_headers_after(&self) -> impl Iterator<Item = (&str, &str)> {
131+
Self::split_headers(&self.request_headers_after)
132+
}
133+
134+
/// Split text headers on `: `
135+
fn split_headers(headers: &[String]) -> impl Iterator<Item = (&str, &str)> {
136+
headers
137+
.iter()
138+
.map(|header| header.split_once(": ").expect("header must contain :"))
139+
}
140+
}
141+
142+
fn check(test_case: TestCase) {
143+
let env = test_case.env();
144+
let mut request = http::Request::builder();
145+
for (name, value) in test_case.request_headers_before() {
146+
request = request.header(name, value);
147+
}
148+
let request = request.body(SdkBody::empty()).expect("must be valid");
149+
let mut context = InterceptorContext::new(TypedBox::new("doesntmatter").erase());
150+
context.set_request(request);
151+
let mut config = ConfigBag::base();
152+
153+
RecursionDetectionInterceptor { env }
154+
.modify_before_signing(&mut context, &mut config)
155+
.expect("interceptor must succeed");
156+
let mutated_request = context.request().expect("request is still set");
157+
for name in mutated_request.headers().keys() {
158+
assert_eq!(
159+
mutated_request.headers().get_all(name).iter().count(),
160+
1,
161+
"No duplicated headers"
162+
)
163+
}
164+
assert_ok(validate_headers(
165+
mutated_request.headers(),
166+
test_case.request_headers_after(),
167+
))
168+
}
169+
}
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
[
2+
{
3+
"env": {},
4+
"requestHeadersBefore": [],
5+
"requestHeadersAfter": [],
6+
"description": [
7+
"The AWS_LAMBDA_FUNCTION_NAME and _X_AMZN_TRACE_ID environment variables are not set.",
8+
"There should be no X-Amzn-Trace-Id header sent."
9+
]
10+
},
11+
{
12+
"env": {
13+
"AWS_LAMBDA_FUNCTION_NAME": "some-function",
14+
"_X_AMZN_TRACE_ID": "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1;lineage=a87bd80c:0,68fd508a:5,c512fbe3:2"
15+
},
16+
"requestHeadersBefore": [],
17+
"requestHeadersAfter": [
18+
"X-Amzn-Trace-Id: Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1;lineage=a87bd80c:0,68fd508a:5,c512fbe3:2"
19+
],
20+
"description": [
21+
"AWS_LAMBDA_FUNCTION_NAME is set, and",
22+
"_X_AMZN_TRACE_ID is set to \"Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1;lineage=a87bd80c:0,68fd508a:5,c512fbe3:2\".",
23+
"The X-Amzn-Trace-Id header should be sent with that value."
24+
]
25+
},
26+
{
27+
"env": {
28+
"_X_AMZN_TRACE_ID": "Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1;lineage=a87bd80c:0,68fd508a:5,c512fbe3:2"
29+
},
30+
"requestHeadersBefore": [],
31+
"requestHeadersAfter": [],
32+
"description": [
33+
"AWS_LAMBDA_FUNCTION_NAME is NOT set, and",
34+
"_X_AMZN_TRACE_ID is set to \"Root=1-5759e988-bd862e3fe1be46a994272793;Parent=53995c3f42cd8ad8;Sampled=1;lineage=a87bd80c:0,68fd508a:5,c512fbe3:2\".",
35+
"The X-Amzn-Trace-Id header should NOT be sent with that value."
36+
]
37+
},
38+
{
39+
"env": {
40+
"AWS_LAMBDA_FUNCTION_NAME": "some-function",
41+
"_X_AMZN_TRACE_ID": "EnvValue"
42+
},
43+
"requestHeadersBefore": [
44+
"X-Amzn-Trace-Id: OriginalValue"
45+
],
46+
"requestHeadersAfter": [
47+
"X-Amzn-Trace-Id: OriginalValue"
48+
],
49+
"desciption": [
50+
"AWS_LAMBDA_FUNCTION_NAME is set, and",
51+
"_X_AMZN_TRACE_ID is set to \"EnvValue\",",
52+
"but the X-Amzn-Trace-Id header is already set on the request.",
53+
"The X-Amzn-Trace-Id header should keep its original value."
54+
]
55+
},
56+
{
57+
"env": {
58+
"AWS_LAMBDA_FUNCTION_NAME": "some-function",
59+
"_X_AMZN_TRACE_ID": "first\nsecond¼\t"
60+
},
61+
"requestHeadersBefore": [],
62+
"requestHeadersAfter": [
63+
"X-Amzn-Trace-Id: first%0Asecond%C2%BC%09"
64+
],
65+
"description": [
66+
"AWS_LAMBDA_FUNCTION_NAME is set, and",
67+
"_X_AMZN_TRACE_ID has ASCII control characters in it.",
68+
"The X-Amzn-Trace-Id header is added with the control characters percent encoded."
69+
]
70+
}
71+
]

aws/sdk-codegen/src/main/kotlin/software/amazon/smithy/rustsdk/AwsCodegenDecorator.kt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ val DECORATORS: List<ClientCodegenDecorator> = listOf(
5151
OperationInputTestDecorator(),
5252
AwsRequestIdDecorator(),
5353
DisabledAuthDecorator(),
54+
RecursionDetectionDecorator(),
5455
),
5556

5657
// Service specific decorators
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package software.amazon.smithy.rustsdk
7+
8+
import software.amazon.smithy.rust.codegen.client.smithy.ClientCodegenContext
9+
import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator
10+
import software.amazon.smithy.rust.codegen.client.smithy.generators.ServiceRuntimePluginCustomization
11+
import software.amazon.smithy.rust.codegen.client.smithy.generators.ServiceRuntimePluginSection
12+
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
13+
import software.amazon.smithy.rust.codegen.core.rustlang.rust
14+
import software.amazon.smithy.rust.codegen.core.rustlang.writable
15+
import software.amazon.smithy.rust.codegen.core.util.letIf
16+
17+
class RecursionDetectionDecorator : ClientCodegenDecorator {
18+
override val name: String get() = "RecursionDetectionDecorator"
19+
override val order: Byte get() = 0
20+
21+
override fun serviceRuntimePluginCustomizations(
22+
codegenContext: ClientCodegenContext,
23+
baseCustomizations: List<ServiceRuntimePluginCustomization>,
24+
): List<ServiceRuntimePluginCustomization> =
25+
baseCustomizations.letIf(codegenContext.settings.codegenConfig.enableNewSmithyRuntime) {
26+
it + listOf(RecursionDetectionRuntimePluginCustomization(codegenContext))
27+
}
28+
}
29+
30+
private class RecursionDetectionRuntimePluginCustomization(
31+
private val codegenContext: ClientCodegenContext,
32+
) : ServiceRuntimePluginCustomization() {
33+
override fun section(section: ServiceRuntimePluginSection): Writable = writable {
34+
if (section is ServiceRuntimePluginSection.AdditionalConfig) {
35+
section.registerInterceptor(codegenContext.runtimeConfig, this) {
36+
rust(
37+
"#T::new()",
38+
AwsRuntimeType.awsRuntime(codegenContext.runtimeConfig)
39+
.resolve("recursion_detection::RecursionDetectionInterceptor"),
40+
)
41+
}
42+
}
43+
}
44+
}

aws/sra-test/integration-tests/aws-sdk-s3/tests/sra_test.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use aws_credential_types::cache::{CredentialsCache, SharedCredentialsCache};
77
use aws_credential_types::provider::SharedCredentialsProvider;
88
use aws_http::user_agent::{ApiMetadata, AwsUserAgent};
99
use aws_runtime::auth::sigv4::SigV4OperationSigningConfig;
10+
use aws_runtime::recursion_detection::RecursionDetectionInterceptor;
1011
use aws_runtime::user_agent::UserAgentInterceptor;
1112
use aws_sdk_s3::config::{Credentials, Region};
1213
use aws_sdk_s3::operation::list_objects_v2::{
@@ -155,6 +156,7 @@ async fn sra_manual_test() {
155156
cfg.get::<Interceptors<HttpRequest, HttpResponse>>()
156157
.expect("interceptors set")
157158
.register_client_interceptor(Arc::new(UserAgentInterceptor::new()) as _)
159+
.register_client_interceptor(Arc::new(RecursionDetectionInterceptor::new()) as _)
158160
.register_client_interceptor(Arc::new(OverrideSigningTimeInterceptor) as _);
159161
Ok(())
160162
}

0 commit comments

Comments
 (0)