diff --git a/Cargo.lock b/Cargo.lock index 3d491c4..ea80168 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -362,6 +362,7 @@ dependencies = [ "opentelemetry_sdk", "prost", "serde", + "tokio", "tonic", "tonic-build", "tracing", diff --git a/src/common/Cargo.toml b/src/common/Cargo.toml index e6957fa..d247f22 100644 --- a/src/common/Cargo.toml +++ b/src/common/Cargo.toml @@ -48,6 +48,9 @@ features = ["derive"] version = "0.3.18" features = ["env-filter"] +[dependencies.tokio] +version = "1.37.0" + [dependencies.tonic] version = "0.11.0" diff --git a/src/common/src/kill_signals.rs b/src/common/src/kill_signals.rs new file mode 100644 index 0000000..1f85006 --- /dev/null +++ b/src/common/src/kill_signals.rs @@ -0,0 +1,30 @@ +use tokio::signal; +use tracing::info; + +pub async fn wait_for_kill_signals() { + let ctrl_c = async { + signal::ctrl_c() + .await + .expect("Failed to install Ctrl+C handler"); + }; + + #[cfg(unix)] + let terminate = async { + signal::unix::signal(signal::unix::SignalKind::terminate()) + .expect("Failed to install signal handler") + .recv() + .await; + }; + + #[cfg(not(unix))] + let terminate = std::future::pending::<()>(); + + tokio::select! { + _ = ctrl_c => { + info!("Received ctrl_c!"); + }, + _ = terminate => { + info!("Received terminate!"); + }, + } +} diff --git a/src/common/src/lib.rs b/src/common/src/lib.rs index 4598395..1f791e3 100644 --- a/src/common/src/lib.rs +++ b/src/common/src/lib.rs @@ -1,3 +1,4 @@ pub mod grpc; +pub mod kill_signals; pub mod loggers; pub mod options; diff --git a/src/gpt_answer_server/src/main.rs b/src/gpt_answer_server/src/main.rs index 7f6034b..073c69e 100644 --- a/src/gpt_answer_server/src/main.rs +++ b/src/gpt_answer_server/src/main.rs @@ -1,21 +1,29 @@ use clap::{Parser, Subcommand}; use opentelemetry::global; +use tokio::signal; +use tokio::sync::oneshot; +use tokio::sync::oneshot::Receiver; use tonic::transport::Server; +use tracing::info; use common::grpc::gpt_answer::gpt_answer::gpt_answer_service_server::GptAnswerServiceServer; +use common::kill_signals; use common::loggers::telemetry::init_telemetry; use common::options::parse_options; use gpt_answer_server::controllers::gpt_answer::GptAnswerServiceImpl; use gpt_answer_server::options::Options; -pub async fn serve(options: Options) { +pub async fn serve(options: Options, rx: Receiver<()>) { let address = options.server_endpoint.parse().unwrap(); println!("Starting GPT Answer server at {}", options.server_endpoint); let gpt_answer_service = GptAnswerServiceImpl::new("dummy_prop".to_string()); Server::builder() .add_service(GptAnswerServiceServer::new(gpt_answer_service)) - .serve(address) + .serve_with_shutdown(address, async { + rx.await.ok(); + info!("GRPC server shut down"); + }) .await .unwrap(); } @@ -47,11 +55,17 @@ async fn main() { options.log.level.as_str(), ); - let server = tokio::spawn(serve(options)); + let (tx, rx) = oneshot::channel(); + let server = tokio::spawn(serve(options, rx)); - tokio::try_join!(server).expect("Failed to run servers"); + kill_signals::wait_for_kill_signals().await; + + // Send the shutdown signal + let _ = tx.send(()); + tokio::try_join!(server).expect("Failed to run server"); global::shutdown_tracer_provider(); + info!("Shutdown successfully!"); } /// GPT Answer GRPC server. diff --git a/src/public/src/main.rs b/src/public/src/main.rs index bb1df73..5339169 100644 --- a/src/public/src/main.rs +++ b/src/public/src/main.rs @@ -12,6 +12,9 @@ use clap::{Parser, Subcommand}; use deadpool_diesel::postgres::Pool; use deadpool_diesel::{Manager, Runtime}; use opentelemetry::global; +use tokio::signal; +use tokio::sync::oneshot; +use tokio::sync::oneshot::Receiver; use tracing::info; use adapter::repositories::grpc::gpt_answer_client::GptAnswerClient; @@ -19,6 +22,7 @@ use adapter::repositories::in_memory::question::QuestionInMemoryRepository; use adapter::repositories::postgres::question_db::QuestionDBRepository; use cli::options::Options; use cli::router::Router; +use common::kill_signals; use common::loggers::telemetry::init_telemetry; use common::options::parse_options; use rust_core::ports::question::QuestionPort; @@ -50,10 +54,19 @@ async fn main() { options.log.level.as_str(), ); - let server = tokio::spawn(serve(options)); - tokio::try_join!(server).expect("Failed to run servers"); + let (tx, rx) = oneshot::channel(); + let server = tokio::spawn(serve(options, rx)); + + kill_signals::wait_for_kill_signals().await; + + // Send the shutdown signal + let _ = tx.send(()); + + // Wait for the server to finish shutting down + tokio::try_join!(server).expect("Failed to run server"); global::shutdown_tracer_provider(); + info!("Shutdown successfully!"); } /// Simple REST server. @@ -76,7 +89,7 @@ enum Commands { Config, } -pub async fn serve(options: Options) { +pub async fn serve(options: Options, rx: Receiver<()>) { let question_port: Arc = if options.db.in_memory.is_some() { info!("Using in-memory database"); Arc::new(QuestionInMemoryRepository::new()) @@ -103,6 +116,10 @@ pub async fn serve(options: Options) { Ipv4Addr::from_str(options.server.url.as_str()).unwrap(), options.server.port, ); + let (_, server) = warp::serve(routes).bind_with_graceful_shutdown(address, async { + rx.await.ok(); + info!("Warp server shut down"); + }); - warp::serve(routes).run(address).await + server.await; }