@@ -4,27 +4,25 @@ use std::{
4
4
sync:: Arc ,
5
5
} ;
6
6
7
- use adapter:: client:: Locked ;
8
7
use axum:: {
9
8
extract:: { FromRequest , RequestParts } ,
10
9
http:: StatusCode ,
11
10
middleware,
12
11
routing:: get,
13
12
Extension , Router ,
14
13
} ;
15
- use hyper:: {
16
- service:: { make_service_fn, service_fn} ,
17
- Error , Server ,
18
- } ;
14
+ use axum_server:: { tls_rustls:: RustlsConfig , Handle } ;
15
+ use hyper:: { Body , Method , Request , Response } ;
19
16
use once_cell:: sync:: Lazy ;
20
- use primitives:: { config:: Environment , ValidatorId } ;
21
- use redis:: ConnectionInfo ;
17
+ use redis:: { aio:: MultiplexedConnection , ConnectionInfo } ;
22
18
use serde:: { Deserialize , Deserializer } ;
23
- use simple_hyper_server_tls:: { listener_from_pem_files, Protocols , TlsListener } ;
24
- use slog:: { error, info} ;
19
+ use slog:: { error, info, Logger } ;
25
20
use tower:: ServiceBuilder ;
26
21
use tower_http:: cors:: CorsLayer ;
27
22
23
+ use adapter:: { client:: Locked , Adapter } ;
24
+ use primitives:: { config:: Environment , ValidatorId } ;
25
+
28
26
use crate :: {
29
27
db:: { CampaignRemaining , DbPool } ,
30
28
middleware:: {
@@ -42,10 +40,6 @@ use crate::{
42
40
} ,
43
41
} ,
44
42
} ;
45
- use adapter:: Adapter ;
46
- use hyper:: { Body , Method , Request , Response } ;
47
- use redis:: aio:: MultiplexedConnection ;
48
- use slog:: Logger ;
49
43
50
44
/// an error used when deserializing a [`EnvConfig`] instance from environment variables
51
45
/// see [`EnvConfig::from_env()`]
@@ -206,81 +200,50 @@ where
206
200
}
207
201
208
202
impl < C : Locked + ' static > Application < C > {
209
- /// Starts the `hyper` `Server`.
210
- pub async fn run2 ( self , enable_tls : EnableTls ) {
203
+ pub async fn run ( self , enable_tls : EnableTls ) {
211
204
let logger = self . logger . clone ( ) ;
212
205
let socket_addr = match & enable_tls {
213
206
EnableTls :: NoTls ( socket_addr) => socket_addr,
214
207
EnableTls :: Tls { socket_addr, .. } => socket_addr,
215
208
} ;
216
209
217
210
info ! ( & logger, "Listening on socket address: {}!" , socket_addr) ;
211
+ let router = self . axum_routing ( ) . await ;
212
+
213
+ let handle = Handle :: new ( ) ;
214
+
215
+ // Spawn a task to shutdown server.
216
+ tokio:: spawn ( shutdown_signal ( logger. clone ( ) , handle. clone ( ) ) ) ;
218
217
219
218
match enable_tls {
220
219
EnableTls :: NoTls ( socket_addr) => {
221
- let make_service = make_service_fn ( |_| {
222
- let server = self . clone ( ) ;
223
- async move {
224
- Ok :: < _ , Error > ( service_fn ( move |req| {
225
- let server = server. clone ( ) ;
226
- async move { Ok :: < _ , Error > ( server. handle_routing ( req) . await ) }
227
- } ) )
228
- }
229
- } ) ;
230
-
231
- let server = Server :: bind ( & socket_addr)
232
- . serve ( make_service)
233
- . with_graceful_shutdown ( shutdown_signal ( logger. clone ( ) ) ) ;
234
-
235
- if let Err ( e) = server. await {
236
- error ! ( & logger, "server error: {}" , e; "main" => "run" ) ;
237
- }
238
- }
239
- EnableTls :: Tls { listener, .. } => {
240
- let make_service = make_service_fn ( |_| {
241
- let server = self . clone ( ) ;
242
- async move {
243
- Ok :: < _ , Error > ( service_fn ( move |req| {
244
- let server = server. clone ( ) ;
245
- async move { Ok :: < _ , Error > ( server. handle_routing ( req) . await ) }
246
- } ) )
247
- }
248
- } ) ;
249
-
250
- // TODO: Find a way to redirect to HTTPS
251
- let server = Server :: builder ( listener)
252
- . serve ( make_service)
253
- . with_graceful_shutdown ( shutdown_signal ( logger. clone ( ) ) ) ;
220
+ let server = axum_server:: bind ( socket_addr)
221
+ . handle ( handle)
222
+ . serve ( router. into_make_service ( ) ) ;
223
+
254
224
tokio:: pin!( server) ;
255
225
256
226
while let Err ( e) = ( & mut server) . await {
257
227
// This is usually caused by trying to connect on HTTP instead of HTTPS
258
228
error ! ( & logger, "server error: {}" , e; "main" => "run" ) ;
259
229
}
260
230
}
261
- }
262
- }
263
231
264
- pub async fn run ( self , enable_tls : EnableTls ) {
265
- let logger = self . logger . clone ( ) ;
266
- let socket_addr = match & enable_tls {
267
- EnableTls :: NoTls ( socket_addr) => socket_addr,
268
- EnableTls :: Tls { socket_addr, .. } => socket_addr,
269
- } ;
270
-
271
- info ! ( & logger, "Listening on socket address: {}!" , socket_addr) ;
272
-
273
- let app = self . axum_routing ( ) . await ;
274
-
275
- let server = axum:: Server :: bind ( socket_addr)
276
- . serve ( app. into_make_service ( ) )
277
- . with_graceful_shutdown ( shutdown_signal ( logger. clone ( ) ) ) ;
232
+ EnableTls :: Tls {
233
+ config,
234
+ socket_addr,
235
+ } => {
236
+ let server = axum_server:: bind_rustls ( socket_addr, config)
237
+ . handle ( handle)
238
+ . serve ( router. into_make_service ( ) ) ;
278
239
279
- tokio:: pin!( server) ;
240
+ tokio:: pin!( server) ;
280
241
281
- while let Err ( e) = ( & mut server) . await {
282
- // This is usually caused by trying to connect on HTTP instead of HTTPS
283
- error ! ( & logger, "server error: {}" , e; "main" => "run" ) ;
242
+ while let Err ( e) = ( & mut server) . await {
243
+ // This is usually caused by trying to connect on HTTP instead of HTTPS
244
+ error ! ( & logger, "server error: {}" , e; "main" => "run" ) ;
245
+ }
246
+ }
284
247
}
285
248
}
286
249
}
@@ -304,21 +267,20 @@ pub enum EnableTls {
304
267
NoTls ( SocketAddr ) ,
305
268
Tls {
306
269
socket_addr : SocketAddr ,
307
- listener : TlsListener ,
270
+ config : RustlsConfig ,
308
271
} ,
309
272
}
310
273
311
274
impl EnableTls {
312
- pub fn new_tls < C : AsRef < Path > , K : AsRef < Path > > (
275
+ pub async fn new_tls < C : AsRef < Path > , K : AsRef < Path > > (
313
276
certificates : C ,
314
277
private_keys : K ,
315
278
socket_addr : SocketAddr ,
316
279
) -> Result < Self , Box < dyn std:: error:: Error > > {
317
- let listener =
318
- listener_from_pem_files ( certificates, private_keys, Protocols :: ALL , & socket_addr) ?;
280
+ let config = RustlsConfig :: from_pem_file ( certificates, private_keys) . await ?;
319
281
320
282
Ok ( Self :: Tls {
321
- listener ,
283
+ config ,
322
284
socket_addr,
323
285
} )
324
286
}
@@ -368,12 +330,15 @@ where
368
330
}
369
331
370
332
/// A Ctrl+C signal to gracefully shutdown the server
371
- async fn shutdown_signal ( logger : Logger ) {
333
+ async fn shutdown_signal ( logger : Logger , handle : Handle ) {
372
334
// Wait for the Ctrl+C signal
373
335
tokio:: signal:: ctrl_c ( )
374
336
. await
375
337
. expect ( "failed to install CTRL+C signal handler" ) ;
376
338
339
+ // Signal the server to shutdown using Handle.
340
+ handle. shutdown ( ) ;
341
+
377
342
info ! ( & logger, "Received Ctrl+C signal. Shutting down.." )
378
343
}
379
344
0 commit comments