Skip to content

Commit d89a90d

Browse files
82marbagDaniele Ahmed
andauthored
Add request ID to response headers (#2438)
* Add request ID to response headers Signed-off-by: Daniele Ahmed <ahmeddan@amazon.de> * Add parsing test Signed-off-by: Daniele Ahmed <ahmeddan@amazon.de> * Style Signed-off-by: Daniele Ahmed <ahmeddan@amazon.de> * CHANGELOG Signed-off-by: Daniele Ahmed <ahmeddan@amazon.de> * Fix import Signed-off-by: Daniele Ahmed <ahmeddan@amazon.de> * Panic if ServerRequestIdProviderLayer is not present Signed-off-by: Daniele Ahmed <ahmeddan@amazon.de> * Own value Signed-off-by: Daniele Ahmed <ahmeddan@amazon.de> * Correct docs Signed-off-by: Daniele Ahmed <ahmeddan@amazon.de> * Add order of layer to expect() message Signed-off-by: Daniele Ahmed <ahmeddan@amazon.de> * Remove Box Signed-off-by: Daniele Ahmed <ahmeddan@amazon.de> * Require order of request ID layers Signed-off-by: Daniele Ahmed <ahmeddan@amazon.de> * Revert "Require order of request ID layers" This reverts commit 147eef2. * One layer to generate and inject the header Signed-off-by: Daniele Ahmed <ahmeddan@amazon.de> * HeaderName for header name Signed-off-by: Daniele Ahmed <ahmeddan@amazon.de> * CHANGELOG Signed-off-by: Daniele Ahmed <ahmeddan@amazon.de> * Remove additional layer Signed-off-by: Daniele Ahmed <ahmeddan@amazon.de> * Remove to_owned Signed-off-by: Daniele Ahmed <ahmeddan@amazon.de> * Add tests, remove unnecessary clone Signed-off-by: Daniele Ahmed <ahmeddan@amazon.de> * take() ResponsePackage instead Signed-off-by: Daniele Ahmed <ahmeddan@amazon.de> * Update docs Signed-off-by: Daniele Ahmed <ahmeddan@amazon.de> * Update docs Signed-off-by: Daniele Ahmed <ahmeddan@amazon.de> * cargo fmt Signed-off-by: Daniele Ahmed <ahmeddan@amazon.de> * Update CHANGELOG Signed-off-by: Daniele Ahmed <ahmeddan@amazon.de> --------- Signed-off-by: Daniele Ahmed <ahmeddan@amazon.de> Co-authored-by: Daniele Ahmed <ahmeddan@amazon.de>
1 parent abbf78f commit d89a90d

File tree

2 files changed

+145
-9
lines changed

2 files changed

+145
-9
lines changed

CHANGELOG.next.toml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,3 +369,15 @@ message = "Increase Tokio version to 1.23.1 for all crates. This is to address [
369369
references = ["smithy-rs#2474"]
370370
meta = { "breaking" = false, "tada" = false, "bug" = false }
371371
author = "rcoh"
372+
373+
[[smithy-rs]]
374+
message = """Servers can send the `ServerRequestId` in the response headers.
375+
Servers need to create their service using the new layer builder `ServerRequestIdProviderLayer::new_with_response_header`:
376+
```
377+
let app = app
378+
.layer(&ServerRequestIdProviderLayer::new_with_response_header(HeaderName::from_static("x-request-id")));
379+
```
380+
"""
381+
references = ["smithy-rs#2438"]
382+
meta = { "breaking" = true, "tada" = false, "bug" = false, "target" = "server"}
383+
author = "82marbag"

rust-runtime/aws-smithy-http-server/src/request/request_id.rs

Lines changed: 133 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@
1212
//! A [`ServerRequestId`] is an opaque random identifier generated by the server every time it receives a request.
1313
//! It uniquely identifies the request within that service instance. It can be used to collate all logs, events and
1414
//! data related to a single operation.
15+
//! Use [`ServerRequestIdProviderLayer::new`] to use [`ServerRequestId`] in your handler.
1516
//!
1617
//! The [`ServerRequestId`] can be returned to the caller, who can in turn share the [`ServerRequestId`] to help the service owner in troubleshooting issues related to their usage of the service.
18+
//! Use [`ServerRequestIdProviderLayer::new_with_response_header`] to use [`ServerRequestId`] in your handler and add it to the response headers.
1719
//!
1820
//! The [`ServerRequestId`] is not meant to be propagated to downstream dependencies of the service. You should rely on a distributed tracing implementation for correlation purposes (e.g. OpenTelemetry).
1921
//!
@@ -34,20 +36,24 @@
3436
//! .operation(handler)
3537
//! .build().unwrap();
3638
//!
37-
//! let app = app.layer(&ServerRequestIdProviderLayer::new()); /* Generate a server request ID */
39+
//! let app = app
40+
//! .layer(&ServerRequestIdProviderLayer::new_with_response_header(HeaderName::from_static("x-request-id"))); /* Generate a server request ID and add it to the response header */
3841
//!
3942
//! let bind: std::net::SocketAddr = format!("{}:{}", args.address, args.port)
4043
//! .parse()
4144
//! .expect("unable to parse the server bind address and port");
4245
//! let server = hyper::Server::bind(&bind).serve(app.into_make_service());
4346
//! ```
4447
48+
use std::future::Future;
4549
use std::{
4650
fmt::Display,
4751
task::{Context, Poll},
4852
};
4953

54+
use futures_util::TryFuture;
5055
use http::request::Parts;
56+
use http::{header::HeaderName, HeaderValue, Response};
5157
use thiserror::Error;
5258
use tower::{Layer, Service};
5359
use uuid::Uuid;
@@ -74,6 +80,10 @@ impl ServerRequestId {
7480
pub fn new() -> Self {
7581
Self { id: Uuid::new_v4() }
7682
}
83+
84+
pub(crate) fn to_header(&self) -> HeaderValue {
85+
HeaderValue::from_str(&self.id.to_string()).expect("This string contains only valid ASCII")
86+
}
7787
}
7888

7989
impl Display for ServerRequestId {
@@ -99,17 +109,28 @@ impl Default for ServerRequestId {
99109
#[derive(Clone)]
100110
pub struct ServerRequestIdProvider<S> {
101111
inner: S,
112+
header_key: Option<HeaderName>,
102113
}
103114

104115
/// A layer that provides services with a unique request ID instance
105116
#[derive(Debug)]
106117
#[non_exhaustive]
107-
pub struct ServerRequestIdProviderLayer;
118+
pub struct ServerRequestIdProviderLayer {
119+
header_key: Option<HeaderName>,
120+
}
108121

109122
impl ServerRequestIdProviderLayer {
110-
/// Generate a new unique request ID
123+
/// Generate a new unique request ID and do not add it as a response header
124+
/// Use [`ServerRequestIdProviderLayer::new_with_response_header`] to also add it as a response header
111125
pub fn new() -> Self {
112-
Self {}
126+
Self { header_key: None }
127+
}
128+
129+
/// Generate a new unique request ID and add it as a response header
130+
pub fn new_with_response_header(header_key: HeaderName) -> Self {
131+
Self {
132+
header_key: Some(header_key),
133+
}
113134
}
114135
}
115136

@@ -123,25 +144,47 @@ impl<S> Layer<S> for ServerRequestIdProviderLayer {
123144
type Service = ServerRequestIdProvider<S>;
124145

125146
fn layer(&self, inner: S) -> Self::Service {
126-
ServerRequestIdProvider { inner }
147+
ServerRequestIdProvider {
148+
inner,
149+
header_key: self.header_key.clone(),
150+
}
127151
}
128152
}
129153

130154
impl<Body, S> Service<http::Request<Body>> for ServerRequestIdProvider<S>
131155
where
132-
S: Service<http::Request<Body>>,
156+
S: Service<http::Request<Body>, Response = Response<crate::body::BoxBody>>,
157+
S::Future: std::marker::Send + 'static,
133158
{
134159
type Response = S::Response;
135160
type Error = S::Error;
136-
type Future = S::Future;
161+
type Future = ServerRequestIdResponseFuture<S::Future>;
137162

138163
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
139164
self.inner.poll_ready(cx)
140165
}
141166

142167
fn call(&mut self, mut req: http::Request<Body>) -> Self::Future {
143-
req.extensions_mut().insert(ServerRequestId::new());
144-
self.inner.call(req)
168+
let request_id = ServerRequestId::new();
169+
match &self.header_key {
170+
Some(header_key) => {
171+
req.extensions_mut().insert(request_id.clone());
172+
ServerRequestIdResponseFuture {
173+
response_package: Some(ResponsePackage {
174+
request_id,
175+
header_key: header_key.clone(),
176+
}),
177+
fut: self.inner.call(req),
178+
}
179+
}
180+
None => {
181+
req.extensions_mut().insert(request_id);
182+
ServerRequestIdResponseFuture {
183+
response_package: None,
184+
fut: self.inner.call(req),
185+
}
186+
}
187+
}
145188
}
146189
}
147190

@@ -150,3 +193,84 @@ impl<Protocol> IntoResponse<Protocol> for MissingServerRequestId {
150193
internal_server_error()
151194
}
152195
}
196+
197+
struct ResponsePackage {
198+
request_id: ServerRequestId,
199+
header_key: HeaderName,
200+
}
201+
202+
pin_project_lite::pin_project! {
203+
pub struct ServerRequestIdResponseFuture<Fut> {
204+
response_package: Option<ResponsePackage>,
205+
#[pin]
206+
fut: Fut,
207+
}
208+
}
209+
210+
impl<Fut> Future for ServerRequestIdResponseFuture<Fut>
211+
where
212+
Fut: TryFuture<Ok = Response<crate::body::BoxBody>>,
213+
{
214+
type Output = Result<Fut::Ok, Fut::Error>;
215+
216+
fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
217+
let this = self.project();
218+
let fut = this.fut;
219+
let response_package = this.response_package;
220+
fut.try_poll(cx).map_ok(|mut res| {
221+
if let Some(response_package) = response_package.take() {
222+
res.headers_mut()
223+
.insert(response_package.header_key, response_package.request_id.to_header());
224+
}
225+
res
226+
})
227+
}
228+
}
229+
230+
#[cfg(test)]
231+
mod tests {
232+
use super::*;
233+
use crate::body::{Body, BoxBody};
234+
use crate::request::Request;
235+
use http::HeaderValue;
236+
use std::convert::Infallible;
237+
use tower::{service_fn, ServiceBuilder, ServiceExt};
238+
239+
#[test]
240+
fn test_request_id_parsed_by_header_value_infallible() {
241+
ServerRequestId::new().to_header();
242+
}
243+
244+
#[tokio::test]
245+
async fn test_request_id_in_response_header() {
246+
let svc = ServiceBuilder::new()
247+
.layer(&ServerRequestIdProviderLayer::new_with_response_header(
248+
HeaderName::from_static("x-request-id"),
249+
))
250+
.service(service_fn(|_req: Request<Body>| async move {
251+
Ok::<_, Infallible>(Response::new(BoxBody::default()))
252+
}));
253+
254+
let req = Request::new(Body::empty());
255+
256+
let res = svc.oneshot(req).await.unwrap();
257+
let request_id = res.headers().get("x-request-id").unwrap().to_str().unwrap();
258+
259+
assert!(HeaderValue::from_str(request_id).is_ok());
260+
}
261+
262+
#[tokio::test]
263+
async fn test_request_id_not_in_response_header() {
264+
let svc = ServiceBuilder::new()
265+
.layer(&ServerRequestIdProviderLayer::new())
266+
.service(service_fn(|_req: Request<Body>| async move {
267+
Ok::<_, Infallible>(Response::new(BoxBody::default()))
268+
}));
269+
270+
let req = Request::new(Body::empty());
271+
272+
let res = svc.oneshot(req).await.unwrap();
273+
274+
assert!(res.headers().is_empty());
275+
}
276+
}

0 commit comments

Comments
 (0)