Skip to content

Commit f0e1e60

Browse files
committed
sentry - axum refactor:
- routes - campaign - create - routes - campaign - event submissions - routes - channel - list - routes - channel - dummy deposit - routers - if dummy adapter middleware - middleware - auth - authenticate & authentication_required - middleware - campaign - load - response - impl IntoResponse for ResponseError - application - Qs - query string extractor based on serde_qs
1 parent 3a52a0b commit f0e1e60

File tree

8 files changed

+618
-56
lines changed

8 files changed

+618
-56
lines changed

sentry/src/application.rs

Lines changed: 92 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
11
use std::{
22
net::{IpAddr, Ipv4Addr, SocketAddr},
33
path::Path,
4+
sync::Arc,
45
};
56

67
use adapter::client::Locked;
8+
use axum::{
9+
extract::{FromRequest, RequestParts},
10+
http::StatusCode,
11+
middleware, Extension, Router,
12+
};
713
use hyper::{
814
service::{make_service_fn, service_fn},
915
Error, Server,
@@ -14,19 +20,24 @@ use redis::ConnectionInfo;
1420
use serde::{Deserialize, Deserializer};
1521
use simple_hyper_server_tls::{listener_from_pem_files, Protocols, TlsListener};
1622
use slog::{error, info};
23+
use tower::ServiceBuilder;
24+
use tower_http::cors::CorsLayer;
1725

1826
use crate::{
1927
db::{CampaignRemaining, DbPool},
2028
middleware::{
21-
auth::Authenticate,
29+
auth::{authenticate, Authenticate},
2230
cors::{cors, Cors},
2331
Middleware,
2432
},
2533
platform::PlatformApi,
2634
response::{map_response_error, ResponseError},
2735
routes::{
2836
get_cfg,
29-
routers::{analytics_router, campaigns_router, channels_router},
37+
routers::{
38+
analytics_router, campaigns_router, campaigns_router_axum, channels_router,
39+
channels_router_axum,
40+
},
3041
},
3142
};
3243
use adapter::Adapter;
@@ -158,11 +169,45 @@ where
158169
response.headers_mut().extend(headers);
159170
response
160171
}
172+
173+
pub async fn axum_routing(&self) -> Router {
174+
let cors = CorsLayer::new()
175+
// "GET,HEAD,PUT,PATCH,POST,DELETE"
176+
.allow_methods([
177+
Method::GET,
178+
Method::HEAD,
179+
Method::PUT,
180+
Method::PATCH,
181+
Method::POST,
182+
Method::DELETE,
183+
])
184+
// allow requests from any origin
185+
// "*"
186+
.allow_origin(tower_http::cors::Any);
187+
188+
let channels = channels_router_axum::<C>();
189+
190+
let campaigns = campaigns_router_axum::<C>();
191+
192+
let router = Router::new()
193+
.nest("/channel", channels)
194+
.nest("/campaign", campaigns);
195+
196+
Router::new()
197+
.nest("/v5", router)
198+
.layer(
199+
// keeps the order from top to bottom!
200+
ServiceBuilder::new()
201+
.layer(cors)
202+
.layer(middleware::from_fn(authenticate::<C, _>)),
203+
)
204+
.layer(Extension(Arc::new(self.clone())))
205+
}
161206
}
162207

163208
impl<C: Locked + 'static> Application<C> {
164209
/// Starts the `hyper` `Server`.
165-
pub async fn run(self, enable_tls: EnableTls) {
210+
pub async fn run2(self, enable_tls: EnableTls) {
166211
let logger = self.logger.clone();
167212
let socket_addr = match &enable_tls {
168213
EnableTls::NoTls(socket_addr) => socket_addr,
@@ -215,6 +260,29 @@ impl<C: Locked + 'static> Application<C> {
215260
}
216261
}
217262
}
263+
264+
pub async fn run(self, enable_tls: EnableTls) {
265+
let logger = self.logger.clone();
266+
let socket_addr = match &enable_tls {
267+
EnableTls::NoTls(socket_addr) => socket_addr,
268+
EnableTls::Tls { socket_addr, .. } => socket_addr,
269+
};
270+
271+
info!(&logger, "Listening on socket address: {}!", socket_addr);
272+
273+
let app = self.axum_routing().await;
274+
275+
let server = axum::Server::bind(socket_addr)
276+
.serve(app.into_make_service())
277+
.with_graceful_shutdown(shutdown_signal(logger.clone()));
278+
279+
tokio::pin!(server);
280+
281+
while let Err(e) = (&mut server).await {
282+
// This is usually caused by trying to connect on HTTP instead of HTTPS
283+
error!(&logger, "server error: {}", e; "main" => "run");
284+
}
285+
}
218286
}
219287

220288
impl<C: Locked> Clone for Application<C> {
@@ -278,6 +346,27 @@ pub struct Auth {
278346
pub chain: primitives::Chain,
279347
}
280348

349+
/// A query string deserialized using `serde_qs` instead of axum's `serde_urlencoded`
350+
pub struct Qs<T>(pub T);
351+
352+
#[axum::async_trait]
353+
impl<T, B> FromRequest<B> for Qs<T>
354+
where
355+
T: serde::de::DeserializeOwned,
356+
B: Send,
357+
{
358+
type Rejection = (StatusCode, String);
359+
360+
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
361+
let query = req.uri().query().unwrap_or_default();
362+
363+
match serde_qs::from_str(query) {
364+
Ok(query) => Ok(Self(query)),
365+
Err(err) => Err((StatusCode::BAD_REQUEST, err.to_string())),
366+
}
367+
}
368+
}
369+
281370
/// A Ctrl+C signal to gracefully shutdown the server
282371
async fn shutdown_signal(logger: Logger) {
283372
// Wait for the Ctrl+C signal

sentry/src/middleware/auth.rs

Lines changed: 93 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1-
use std::error;
1+
use std::{error, sync::Arc};
22

33
use async_trait::async_trait;
4-
use hyper::header::{AUTHORIZATION, REFERER};
5-
use hyper::{Body, Request};
4+
use axum::middleware::Next;
5+
use hyper::{
6+
header::{AUTHORIZATION, REFERER},
7+
Body, Request,
8+
};
69
use redis::aio::MultiplexedConnection;
710

811
use adapter::{prelude::*, primitives::Session as AdapterSession, Adapter};
@@ -71,6 +74,92 @@ impl<C: Locked + 'static> Middleware<C> for IsAdmin {
7174
}
7275
}
7376

77+
pub async fn authentication_required<C: Locked + 'static, B>(
78+
request: axum::http::Request<B>,
79+
next: Next<B>,
80+
) -> Result<axum::response::Response, ResponseError> {
81+
if request.extensions().get::<Auth>().is_some() {
82+
Ok(next.run(request).await)
83+
} else {
84+
Err(ResponseError::Unauthorized)
85+
}
86+
}
87+
88+
pub async fn authenticate<C: Locked + 'static, B>(
89+
mut request: axum::http::Request<B>,
90+
next: Next<B>,
91+
) -> Result<axum::response::Response, ResponseError> {
92+
let (adapter, redis) = {
93+
let app = request
94+
.extensions()
95+
.get::<Arc<Application<C>>>()
96+
.expect("Application should always be present");
97+
98+
(app.adapter.clone(), app.redis.clone())
99+
};
100+
101+
let referrer = request
102+
.headers()
103+
.get(REFERER)
104+
.and_then(|hv| hv.to_str().ok().map(ToString::to_string));
105+
106+
let session = Session {
107+
ip: get_request_ip(&request),
108+
country: None,
109+
referrer_header: referrer,
110+
os: None,
111+
};
112+
request.extensions_mut().insert(session);
113+
114+
let authorization = request.headers().get(AUTHORIZATION);
115+
116+
let prefix = "Bearer ";
117+
118+
let token = authorization
119+
.and_then(|hv| {
120+
hv.to_str()
121+
.map(|token_str| token_str.strip_prefix(prefix))
122+
.transpose()
123+
})
124+
.transpose()?;
125+
126+
if let Some(token) = token {
127+
let adapter_session = match redis::cmd("GET")
128+
.arg(token)
129+
.query_async::<_, Option<String>>(&mut redis.clone())
130+
.await?
131+
.and_then(|session_str| serde_json::from_str::<AdapterSession>(&session_str).ok())
132+
{
133+
Some(adapter_session) => adapter_session,
134+
None => {
135+
// If there was a problem with the Session or the Token, this will error
136+
// and a BadRequest response will be returned
137+
let adapter_session = adapter.session_from_token(token).await?;
138+
139+
// save the Adapter Session to Redis for the next request
140+
// if serde errors on deserialization this will override the value inside
141+
redis::cmd("SET")
142+
.arg(token)
143+
.arg(serde_json::to_string(&adapter_session)?)
144+
.query_async(&mut redis.clone())
145+
.await?;
146+
147+
adapter_session
148+
}
149+
};
150+
151+
let auth = Auth {
152+
era: adapter_session.era,
153+
uid: ValidatorId::from(adapter_session.uid),
154+
chain: adapter_session.chain,
155+
};
156+
157+
request.extensions_mut().insert(auth);
158+
}
159+
160+
Ok(next.run(request).await)
161+
}
162+
74163
/// Check `Authorization` header for `Bearer` scheme with `Adapter::session_from_token`.
75164
/// If the `Adapter` fails to create an `AdapterSession`, `ResponseError::BadRequest` will be returned.
76165
async fn for_request<C: Locked>(
@@ -140,7 +229,7 @@ async fn for_request<C: Locked>(
140229
Ok(req)
141230
}
142231

143-
fn get_request_ip(req: &Request<Body>) -> Option<String> {
232+
fn get_request_ip<B>(req: &Request<B>) -> Option<String> {
144233
req.headers()
145234
.get("true-client-ip")
146235
.or_else(|| req.headers().get("x-forwarded-for"))

0 commit comments

Comments
 (0)