Skip to content

cli: Close the websocket connection gracefully #2925

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 92 additions & 35 deletions crates/cli/src/subcommands/subscribe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@ use spacetimedb_data_structures::map::HashMap;
use spacetimedb_lib::db::raw_def::v9::RawModuleDefV9;
use spacetimedb_lib::de::serde::{DeserializeWrapper, SeedWrapper};
use spacetimedb_lib::ser::serde::SerializeWrapper;
use std::io;
use std::time::Duration;
use thiserror::Error;
use tokio::io::AsyncWriteExt;
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
use tokio_tungstenite::tungstenite::Message as WsMessage;
use tokio_tungstenite::tungstenite::{Error as WsError, Message as WsMessage};

use crate::api::ClientApi;
use crate::common_args;
Expand Down Expand Up @@ -155,35 +157,80 @@ pub async fn exec(config: Config, args: &ArgMatches) -> Result<(), anyhow::Error
if let Some(auth_header) = api.con.auth_header.to_header() {
req.headers_mut().insert(header::AUTHORIZATION, auth_header);
}
let (mut ws, _) = tokio_tungstenite::connect_async(req).await?;
let mut ws = tokio_tungstenite::connect_async(req).await.map(|(ws, _)| ws)?;

let task = async {
subscribe(&mut ws, queries.cloned().map(Into::into).collect()).await?;
await_initial_update(&mut ws, print_initial_update.then_some(&module_def)).await?;
consume_transaction_updates(&mut ws, num, &module_def).await
};

let needs_shutdown = if let Some(timeout) = timeout {
let res = if let Some(timeout) = timeout {
let timeout = Duration::from_secs(timeout.into());
match tokio::time::timeout(timeout, task).await {
Ok(res) => res?,
Err(_elapsed) => true,
Ok(res) => res,
Err(_elapsed) => {
eprintln!("timed out after {}s", timeout.as_secs());
Ok(())
}
}
} else {
task.await?
task.await
};

if needs_shutdown {
ws.close(None).await?;
// Close the connection gracefully, unless it's a websocket error,
// in which case the connection is most likely already unusable.
if !matches!(res, Err(Error::Subscribe { .. } | Error::Websocket { .. })) {
// Ignore errors here, we're going to drop the connection anyways.
let _ = ws.close(None).await;
}

Ok(())
res.or_else(|e| if e.is_closed_normally() { Ok(()) } else { Err(e) })
.map_err(anyhow::Error::from)
}

#[derive(Debug, Error)]
enum Error {
#[error("error sending subscription queries")]
Subscribe {
#[source]
source: WsError,
},
#[error("protocol error: {details}")]
Protocol { details: &'static str },
#[error("websocket error: {source}")]
Websocket {
#[source]
source: WsError,
},
#[error("encountered failed transaction: {reason}")]
TransactionFailure { reason: Box<str> },
#[error("error formatting response: {source:#}")]
Reformat {
#[source]
source: anyhow::Error,
},
#[error(transparent)]
Serde(#[from] serde_json::Error),
#[error(transparent)]
Io(#[from] io::Error),
}

impl Error {
fn is_closed_normally(&self) -> bool {
matches!(
self,
Self::Websocket {
source: WsError::ConnectionClosed
}
)
}
}

/// Send the subscribe message.
async fn subscribe<S>(ws: &mut S, query_strings: Box<[Box<str>]>) -> Result<(), S::Error>
async fn subscribe<S>(ws: &mut S, query_strings: Box<[Box<str>]>) -> Result<(), Error>
where
S: Sink<WsMessage> + Unpin,
S: Sink<WsMessage, Error = WsError> + Unpin,
{
let msg = serde_json::to_string(&SerializeWrapper::new(ws::ClientMessage::<()>::Subscribe(
ws::Subscribe {
Expand All @@ -192,35 +239,39 @@ where
},
)))
.unwrap();
ws.send(msg.into()).await
ws.send(msg.into()).await.map_err(|source| Error::Subscribe { source })
}

/// Await the initial [`ServerMessage::SubscriptionUpdate`].
/// If `module_def` is `Some`, print a JSON representation to stdout.
async fn await_initial_update<S>(ws: &mut S, module_def: Option<&RawModuleDefV9>) -> anyhow::Result<()>
async fn await_initial_update<S>(ws: &mut S, module_def: Option<&RawModuleDefV9>) -> Result<(), Error>
where
S: TryStream<Ok = WsMessage> + Unpin,
S::Error: std::error::Error + Send + Sync + 'static,
S: TryStream<Ok = WsMessage, Error = WsError> + Unpin,
{
const RECV_TX_UPDATE: &str = "protocol error: received transaction update before initial subscription update";

while let Some(msg) = ws.try_next().await? {
while let Some(msg) = ws.try_next().await.map_err(|source| Error::Websocket { source })? {
let Some(msg) = parse_msg_json(&msg) else { continue };
match msg {
ws::ServerMessage::InitialSubscription(sub) => {
if let Some(module_def) = module_def {
let formatted = reformat_update(&sub.database_update, module_def)?;
let output = serde_json::to_string(&formatted)? + "\n";
let output = format_output_json(&sub.database_update, module_def)?;
tokio::io::stdout().write_all(output.as_bytes()).await?
}
break;
}
ws::ServerMessage::TransactionUpdate(ws::TransactionUpdate { status, .. }) => anyhow::bail!(match status {
ws::UpdateStatus::Failed(msg) => msg,
_ => RECV_TX_UPDATE.into(),
}),
ws::ServerMessage::TransactionUpdate(ws::TransactionUpdate { status, .. }) => {
return Err(match status {
ws::UpdateStatus::Failed(msg) => Error::TransactionFailure { reason: msg },
_ => Error::Protocol {
details: RECV_TX_UPDATE,
},
})
}
ws::ServerMessage::TransactionUpdateLight(ws::TransactionUpdateLight { .. }) => {
anyhow::bail!(RECV_TX_UPDATE)
return Err(Error::Protocol {
details: RECV_TX_UPDATE,
})
}
_ => continue,
}
Expand All @@ -231,41 +282,47 @@ where

/// Print `num` [`ServerMessage::TransactionUpdate`] messages as JSON.
/// If `num` is `None`, keep going indefinitely.
async fn consume_transaction_updates<S>(
ws: &mut S,
num: Option<u32>,
module_def: &RawModuleDefV9,
) -> anyhow::Result<bool>
async fn consume_transaction_updates<S>(ws: &mut S, num: Option<u32>, module_def: &RawModuleDefV9) -> Result<(), Error>
where
S: TryStream<Ok = WsMessage> + Unpin,
S::Error: std::error::Error + Send + Sync + 'static,
S: TryStream<Ok = WsMessage, Error = WsError> + Unpin,
{
let mut stdout = tokio::io::stdout();
let mut num_received = 0;
loop {
if num.is_some_and(|n| num_received >= n) {
break Ok(true);
return Ok(());
}
let Some(msg) = ws.try_next().await? else {
let Some(msg) = ws.try_next().await.map_err(|source| Error::Websocket { source })? else {
eprintln!("disconnected by server");
break Ok(false);
return Err(Error::Websocket {
source: WsError::ConnectionClosed,
});
};

let Some(msg) = parse_msg_json(&msg) else { continue };
match msg {
ws::ServerMessage::InitialSubscription(_) => {
anyhow::bail!("protocol error: received a second initial subscription update")
return Err(Error::Protocol {
details: "received a second initial subscription update",
})
}
ws::ServerMessage::TransactionUpdateLight(ws::TransactionUpdateLight { update, .. })
| ws::ServerMessage::TransactionUpdate(ws::TransactionUpdate {
status: ws::UpdateStatus::Committed(update),
..
}) => {
let output = serde_json::to_string(&reformat_update(&update, module_def)?)? + "\n";
let output = format_output_json(&update, module_def)?;
stdout.write_all(output.as_bytes()).await?;
num_received += 1;
}
_ => continue,
}
}
}

fn format_output_json(msg: &ws::DatabaseUpdate<JsonFormat>, schema: &RawModuleDefV9) -> Result<String, Error> {
let formatted = reformat_update(msg, schema).map_err(|source| Error::Reformat { source })?;
let output = serde_json::to_string(&formatted)? + "\n";

Ok(output)
}
2 changes: 1 addition & 1 deletion smoketests/config.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
default_server = "127.0.0.1:3000"
spacetimedb_token = "eyJ0eXAiOiJKV1QiLCJhbGciOiJFUzI1NiJ9.eyJoZXhfaWRlbnRpdHkiOiJjMjAwYTAwNzQ5ZjM0MTQyZmQzYjk3YmZmNjE1OWNkZjg2YmNkYzI2ZGZkMDgwYjcxZWVkYWY0MTcxYmYxMjg5Iiwic3ViIjoiOGZhY2JlN2ItNzg4NS00MmMyLTg3NTktN2M4NGJmNWMyMGU1IiwiaXNzIjoibG9jYWxob3N0IiwiYXVkIjpbInNwYWNldGltZWRiIl0sImlhdCI6MTc0MzQzNzEwMywiZXhwIjpudWxsfQ.ON0q_bu6WLWzDWh5AQ4b601spdZ46qKWg6SWHd9IcoLi7iRx-Jr4z5XnZpkkdcSWOQ4FU81ewn5JmvScoQrOPg"
spacetimedb_token = "eyJ0eXAiOiJKV1QiLCJhbGciOiJFUzI1NiJ9.eyJoZXhfaWRlbnRpdHkiOiJjMjAwYzc3NDY1NTE5MDM2MTE4M2JiNjFmMWMxYzY3NDUzMzYzY2MxMTY4MmM1NTUwNWZiNjdlYzI0ZWMyMWViIiwic3ViIjoiOTJlMmNkOGQtNTk5Ny00NjZlLWIwNmYtZDNjOGQ1NzU3ODI4IiwiaXNzIjoibG9jYWxob3N0IiwiYXVkIjpbInNwYWNldGltZWRiIl0sImlhdCI6MTc1MjA0NjgwMCwiZXhwIjpudWxsfQ.dgefoxC7eCOONVUufu2JTVFo9876zQ4Mqwm0ivZ0PQK7Hacm3Ip_xqyav4bilZ0vIEf8IM8AB0_xawk8WcbvMg"

[[server_configs]]
nickname = "localhost"
Expand Down
Loading