1
1
use std:: {
2
2
net:: { IpAddr , Ipv4Addr , SocketAddr } ,
3
3
path:: Path ,
4
+ sync:: Arc ,
4
5
} ;
5
6
6
7
use adapter:: client:: Locked ;
8
+ use axum:: {
9
+ extract:: { FromRequest , RequestParts } ,
10
+ http:: StatusCode ,
11
+ middleware, Extension , Router ,
12
+ } ;
7
13
use hyper:: {
8
14
service:: { make_service_fn, service_fn} ,
9
15
Error , Server ,
@@ -14,19 +20,24 @@ use redis::ConnectionInfo;
14
20
use serde:: { Deserialize , Deserializer } ;
15
21
use simple_hyper_server_tls:: { listener_from_pem_files, Protocols , TlsListener } ;
16
22
use slog:: { error, info} ;
23
+ use tower:: ServiceBuilder ;
24
+ use tower_http:: cors:: CorsLayer ;
17
25
18
26
use crate :: {
19
27
db:: { CampaignRemaining , DbPool } ,
20
28
middleware:: {
21
- auth:: Authenticate ,
29
+ auth:: { authenticate , Authenticate } ,
22
30
cors:: { cors, Cors } ,
23
31
Middleware ,
24
32
} ,
25
33
platform:: PlatformApi ,
26
34
response:: { map_response_error, ResponseError } ,
27
35
routes:: {
28
36
get_cfg,
29
- routers:: { analytics_router, campaigns_router, channels_router} ,
37
+ routers:: {
38
+ analytics_router, campaigns_router, campaigns_router_axum, channels_router,
39
+ channels_router_axum,
40
+ } ,
30
41
} ,
31
42
} ;
32
43
use adapter:: Adapter ;
@@ -158,11 +169,45 @@ where
158
169
response. headers_mut ( ) . extend ( headers) ;
159
170
response
160
171
}
172
+
173
+ pub async fn axum_routing ( & self ) -> Router {
174
+ let cors = CorsLayer :: new ( )
175
+ // "GET,HEAD,PUT,PATCH,POST,DELETE"
176
+ . allow_methods ( [
177
+ Method :: GET ,
178
+ Method :: HEAD ,
179
+ Method :: PUT ,
180
+ Method :: PATCH ,
181
+ Method :: POST ,
182
+ Method :: DELETE ,
183
+ ] )
184
+ // allow requests from any origin
185
+ // "*"
186
+ . allow_origin ( tower_http:: cors:: Any ) ;
187
+
188
+ let channels = channels_router_axum :: < C > ( ) ;
189
+
190
+ let campaigns = campaigns_router_axum :: < C > ( ) ;
191
+
192
+ let router = Router :: new ( )
193
+ . nest ( "/channel" , channels)
194
+ . nest ( "/campaign" , campaigns) ;
195
+
196
+ Router :: new ( )
197
+ . nest ( "/v5" , router)
198
+ . layer (
199
+ // keeps the order from top to bottom!
200
+ ServiceBuilder :: new ( )
201
+ . layer ( cors)
202
+ . layer ( middleware:: from_fn ( authenticate :: < C , _ > ) ) ,
203
+ )
204
+ . layer ( Extension ( Arc :: new ( self . clone ( ) ) ) )
205
+ }
161
206
}
162
207
163
208
impl < C : Locked + ' static > Application < C > {
164
209
/// Starts the `hyper` `Server`.
165
- pub async fn run ( self , enable_tls : EnableTls ) {
210
+ pub async fn run2 ( self , enable_tls : EnableTls ) {
166
211
let logger = self . logger . clone ( ) ;
167
212
let socket_addr = match & enable_tls {
168
213
EnableTls :: NoTls ( socket_addr) => socket_addr,
@@ -215,6 +260,29 @@ impl<C: Locked + 'static> Application<C> {
215
260
}
216
261
}
217
262
}
263
+
264
+ pub async fn run ( self , enable_tls : EnableTls ) {
265
+ let logger = self . logger . clone ( ) ;
266
+ let socket_addr = match & enable_tls {
267
+ EnableTls :: NoTls ( socket_addr) => socket_addr,
268
+ EnableTls :: Tls { socket_addr, .. } => socket_addr,
269
+ } ;
270
+
271
+ info ! ( & logger, "Listening on socket address: {}!" , socket_addr) ;
272
+
273
+ let app = self . axum_routing ( ) . await ;
274
+
275
+ let server = axum:: Server :: bind ( socket_addr)
276
+ . serve ( app. into_make_service ( ) )
277
+ . with_graceful_shutdown ( shutdown_signal ( logger. clone ( ) ) ) ;
278
+
279
+ tokio:: pin!( server) ;
280
+
281
+ while let Err ( e) = ( & mut server) . await {
282
+ // This is usually caused by trying to connect on HTTP instead of HTTPS
283
+ error ! ( & logger, "server error: {}" , e; "main" => "run" ) ;
284
+ }
285
+ }
218
286
}
219
287
220
288
impl < C : Locked > Clone for Application < C > {
@@ -278,6 +346,27 @@ pub struct Auth {
278
346
pub chain : primitives:: Chain ,
279
347
}
280
348
349
+ /// A query string deserialized using `serde_qs` instead of axum's `serde_urlencoded`
350
+ pub struct Qs < T > ( pub T ) ;
351
+
352
+ #[ axum:: async_trait]
353
+ impl < T , B > FromRequest < B > for Qs < T >
354
+ where
355
+ T : serde:: de:: DeserializeOwned ,
356
+ B : Send ,
357
+ {
358
+ type Rejection = ( StatusCode , String ) ;
359
+
360
+ async fn from_request ( req : & mut RequestParts < B > ) -> Result < Self , Self :: Rejection > {
361
+ let query = req. uri ( ) . query ( ) . unwrap_or_default ( ) ;
362
+
363
+ match serde_qs:: from_str ( query) {
364
+ Ok ( query) => Ok ( Self ( query) ) ,
365
+ Err ( err) => Err ( ( StatusCode :: BAD_REQUEST , err. to_string ( ) ) ) ,
366
+ }
367
+ }
368
+ }
369
+
281
370
/// A Ctrl+C signal to gracefully shutdown the server
282
371
async fn shutdown_signal ( logger : Logger ) {
283
372
// Wait for the Ctrl+C signal
0 commit comments