Skip to content

cli: Close the websocket connection gracefully #2925

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
30 changes: 15 additions & 15 deletions crates/cli/src/subcommands/subscribe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use spacetimedb_lib::ser::serde::SerializeWrapper;
use std::time::Duration;
use tokio::io::AsyncWriteExt;
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
use tokio_tungstenite::tungstenite::Message as WsMessage;
use tokio_tungstenite::tungstenite::{Error as WsError, Message as WsMessage};

use crate::api::ClientApi;
use crate::common_args;
Expand Down Expand Up @@ -158,29 +158,33 @@ pub async fn exec(config: Config, args: &ArgMatches) -> Result<(), anyhow::Error
if let Some(auth_header) = api.con.auth_header.to_header() {
req.headers_mut().insert(header::AUTHORIZATION, auth_header);
}
let (mut ws, _) = tokio_tungstenite::connect_async(req).await?;
let mut ws = tokio_tungstenite::connect_async(req).await.map(|(ws, _)| ws)?;

let task = async {
subscribe(&mut ws, queries.cloned().map(Into::into).collect()).await?;
await_initial_update(&mut ws, print_initial_update.then_some(&module_def)).await?;
consume_transaction_updates(&mut ws, num, &module_def).await
};

let needs_shutdown = if let Some(timeout) = timeout {
let res = if let Some(timeout) = timeout {
let timeout = Duration::from_secs(timeout.into());
match tokio::time::timeout(timeout, task).await {
Ok(res) => res?,
Err(_elapsed) => true,
Ok(res) => res,
Err(_elapsed) => {
eprintln!("timed out after {}s", timeout.as_secs());
Ok(())
}
}
} else {
task.await?
task.await
};

if needs_shutdown {
// Close the connection gracefully, unless it's a websocket error.
if !res.as_ref().is_err_and(|e| e.downcast_ref::<WsError>().is_some()) {
ws.close(None).await?;
}

Ok(())
res
}

/// Send the subscribe message.
Expand Down Expand Up @@ -234,11 +238,7 @@ where

/// Print `num` [`ServerMessage::TransactionUpdate`] messages as JSON.
/// If `num` is `None`, keep going indefinitely.
async fn consume_transaction_updates<S>(
ws: &mut S,
num: Option<u32>,
module_def: &RawModuleDefV9,
) -> anyhow::Result<bool>
async fn consume_transaction_updates<S>(ws: &mut S, num: Option<u32>, module_def: &RawModuleDefV9) -> anyhow::Result<()>
where
S: TryStream<Ok = WsMessage> + Unpin,
S::Error: std::error::Error + Send + Sync + 'static,
Expand All @@ -247,11 +247,11 @@ where
let mut num_received = 0;
loop {
if num.is_some_and(|n| num_received >= n) {
break Ok(true);
break Ok(());
}
let Some(msg) = ws.try_next().await? else {
eprintln!("disconnected by server");
break Ok(false);
break Ok(());
};

let Some(msg) = parse_msg_json(&msg) else { continue };
Expand Down
Loading