diff --git a/crates/cli/src/subcommands/subscribe.rs b/crates/cli/src/subcommands/subscribe.rs index 5f61908ebfa..bab3ea925ce 100644 --- a/crates/cli/src/subcommands/subscribe.rs +++ b/crates/cli/src/subcommands/subscribe.rs @@ -9,10 +9,12 @@ use spacetimedb_data_structures::map::HashMap; use spacetimedb_lib::db::raw_def::v9::RawModuleDefV9; use spacetimedb_lib::de::serde::{DeserializeWrapper, SeedWrapper}; use spacetimedb_lib::ser::serde::SerializeWrapper; +use std::io; use std::time::Duration; +use thiserror::Error; use tokio::io::AsyncWriteExt; use tokio_tungstenite::tungstenite::client::IntoClientRequest; -use tokio_tungstenite::tungstenite::Message as WsMessage; +use tokio_tungstenite::tungstenite::{Error as WsError, Message as WsMessage}; use crate::api::ClientApi; use crate::common_args; @@ -155,7 +157,7 @@ pub async fn exec(config: Config, args: &ArgMatches) -> Result<(), anyhow::Error if let Some(auth_header) = api.con.auth_header.to_header() { req.headers_mut().insert(header::AUTHORIZATION, auth_header); } - let (mut ws, _) = tokio_tungstenite::connect_async(req).await?; + let mut ws = tokio_tungstenite::connect_async(req).await.map(|(ws, _)| ws)?; let task = async { subscribe(&mut ws, queries.cloned().map(Into::into).collect()).await?; @@ -163,27 +165,72 @@ pub async fn exec(config: Config, args: &ArgMatches) -> Result<(), anyhow::Error consume_transaction_updates(&mut ws, num, &module_def).await }; - let needs_shutdown = if let Some(timeout) = timeout { + let res = if let Some(timeout) = timeout { let timeout = Duration::from_secs(timeout.into()); match tokio::time::timeout(timeout, task).await { - Ok(res) => res?, - Err(_elapsed) => true, + Ok(res) => res, + Err(_elapsed) => { + eprintln!("timed out after {}s", timeout.as_secs()); + Ok(()) + } } } else { - task.await? + task.await }; - if needs_shutdown { - ws.close(None).await?; + // Close the connection gracefully, unless it's a websocket error, + // in which case the connection is most likely already unusable. + if !matches!(res, Err(Error::Subscribe { .. } | Error::Websocket { .. })) { + // Ignore errors here, we're going to drop the connection anyways. + let _ = ws.close(None).await; } - Ok(()) + res.or_else(|e| if e.is_closed_normally() { Ok(()) } else { Err(e) }) + .map_err(anyhow::Error::from) +} + +#[derive(Debug, Error)] +enum Error { + #[error("error sending subscription queries")] + Subscribe { + #[source] + source: WsError, + }, + #[error("protocol error: {details}")] + Protocol { details: &'static str }, + #[error("websocket error: {source}")] + Websocket { + #[source] + source: WsError, + }, + #[error("encountered failed transaction: {reason}")] + TransactionFailure { reason: Box }, + #[error("error formatting response: {source:#}")] + Reformat { + #[source] + source: anyhow::Error, + }, + #[error(transparent)] + Serde(#[from] serde_json::Error), + #[error(transparent)] + Io(#[from] io::Error), +} + +impl Error { + fn is_closed_normally(&self) -> bool { + matches!( + self, + Self::Websocket { + source: WsError::ConnectionClosed + } + ) + } } /// Send the subscribe message. -async fn subscribe(ws: &mut S, query_strings: Box<[Box]>) -> Result<(), S::Error> +async fn subscribe(ws: &mut S, query_strings: Box<[Box]>) -> Result<(), Error> where - S: Sink + Unpin, + S: Sink + Unpin, { let msg = serde_json::to_string(&SerializeWrapper::new(ws::ClientMessage::<()>::Subscribe( ws::Subscribe { @@ -192,35 +239,39 @@ where }, ))) .unwrap(); - ws.send(msg.into()).await + ws.send(msg.into()).await.map_err(|source| Error::Subscribe { source }) } /// Await the initial [`ServerMessage::SubscriptionUpdate`]. /// If `module_def` is `Some`, print a JSON representation to stdout. -async fn await_initial_update(ws: &mut S, module_def: Option<&RawModuleDefV9>) -> anyhow::Result<()> +async fn await_initial_update(ws: &mut S, module_def: Option<&RawModuleDefV9>) -> Result<(), Error> where - S: TryStream + Unpin, - S::Error: std::error::Error + Send + Sync + 'static, + S: TryStream + Unpin, { const RECV_TX_UPDATE: &str = "protocol error: received transaction update before initial subscription update"; - while let Some(msg) = ws.try_next().await? { + while let Some(msg) = ws.try_next().await.map_err(|source| Error::Websocket { source })? { let Some(msg) = parse_msg_json(&msg) else { continue }; match msg { ws::ServerMessage::InitialSubscription(sub) => { if let Some(module_def) = module_def { - let formatted = reformat_update(&sub.database_update, module_def)?; - let output = serde_json::to_string(&formatted)? + "\n"; + let output = format_output_json(&sub.database_update, module_def)?; tokio::io::stdout().write_all(output.as_bytes()).await? } break; } - ws::ServerMessage::TransactionUpdate(ws::TransactionUpdate { status, .. }) => anyhow::bail!(match status { - ws::UpdateStatus::Failed(msg) => msg, - _ => RECV_TX_UPDATE.into(), - }), + ws::ServerMessage::TransactionUpdate(ws::TransactionUpdate { status, .. }) => { + return Err(match status { + ws::UpdateStatus::Failed(msg) => Error::TransactionFailure { reason: msg }, + _ => Error::Protocol { + details: RECV_TX_UPDATE, + }, + }) + } ws::ServerMessage::TransactionUpdateLight(ws::TransactionUpdateLight { .. }) => { - anyhow::bail!(RECV_TX_UPDATE) + return Err(Error::Protocol { + details: RECV_TX_UPDATE, + }) } _ => continue, } @@ -231,37 +282,36 @@ where /// Print `num` [`ServerMessage::TransactionUpdate`] messages as JSON. /// If `num` is `None`, keep going indefinitely. -async fn consume_transaction_updates( - ws: &mut S, - num: Option, - module_def: &RawModuleDefV9, -) -> anyhow::Result +async fn consume_transaction_updates(ws: &mut S, num: Option, module_def: &RawModuleDefV9) -> Result<(), Error> where - S: TryStream + Unpin, - S::Error: std::error::Error + Send + Sync + 'static, + S: TryStream + Unpin, { let mut stdout = tokio::io::stdout(); let mut num_received = 0; loop { if num.is_some_and(|n| num_received >= n) { - break Ok(true); + return Ok(()); } - let Some(msg) = ws.try_next().await? else { + let Some(msg) = ws.try_next().await.map_err(|source| Error::Websocket { source })? else { eprintln!("disconnected by server"); - break Ok(false); + return Err(Error::Websocket { + source: WsError::ConnectionClosed, + }); }; let Some(msg) = parse_msg_json(&msg) else { continue }; match msg { ws::ServerMessage::InitialSubscription(_) => { - anyhow::bail!("protocol error: received a second initial subscription update") + return Err(Error::Protocol { + details: "received a second initial subscription update", + }) } ws::ServerMessage::TransactionUpdateLight(ws::TransactionUpdateLight { update, .. }) | ws::ServerMessage::TransactionUpdate(ws::TransactionUpdate { status: ws::UpdateStatus::Committed(update), .. }) => { - let output = serde_json::to_string(&reformat_update(&update, module_def)?)? + "\n"; + let output = format_output_json(&update, module_def)?; stdout.write_all(output.as_bytes()).await?; num_received += 1; } @@ -269,3 +319,10 @@ where } } } + +fn format_output_json(msg: &ws::DatabaseUpdate, schema: &RawModuleDefV9) -> Result { + let formatted = reformat_update(msg, schema).map_err(|source| Error::Reformat { source })?; + let output = serde_json::to_string(&formatted)? + "\n"; + + Ok(output) +} diff --git a/smoketests/config.toml b/smoketests/config.toml index b8fa63b6df8..b7c4ad31a45 100644 --- a/smoketests/config.toml +++ b/smoketests/config.toml @@ -1,5 +1,5 @@ default_server = "127.0.0.1:3000" -spacetimedb_token = "eyJ0eXAiOiJKV1QiLCJhbGciOiJFUzI1NiJ9.eyJoZXhfaWRlbnRpdHkiOiJjMjAwYTAwNzQ5ZjM0MTQyZmQzYjk3YmZmNjE1OWNkZjg2YmNkYzI2ZGZkMDgwYjcxZWVkYWY0MTcxYmYxMjg5Iiwic3ViIjoiOGZhY2JlN2ItNzg4NS00MmMyLTg3NTktN2M4NGJmNWMyMGU1IiwiaXNzIjoibG9jYWxob3N0IiwiYXVkIjpbInNwYWNldGltZWRiIl0sImlhdCI6MTc0MzQzNzEwMywiZXhwIjpudWxsfQ.ON0q_bu6WLWzDWh5AQ4b601spdZ46qKWg6SWHd9IcoLi7iRx-Jr4z5XnZpkkdcSWOQ4FU81ewn5JmvScoQrOPg" +spacetimedb_token = "eyJ0eXAiOiJKV1QiLCJhbGciOiJFUzI1NiJ9.eyJoZXhfaWRlbnRpdHkiOiJjMjAwYzc3NDY1NTE5MDM2MTE4M2JiNjFmMWMxYzY3NDUzMzYzY2MxMTY4MmM1NTUwNWZiNjdlYzI0ZWMyMWViIiwic3ViIjoiOTJlMmNkOGQtNTk5Ny00NjZlLWIwNmYtZDNjOGQ1NzU3ODI4IiwiaXNzIjoibG9jYWxob3N0IiwiYXVkIjpbInNwYWNldGltZWRiIl0sImlhdCI6MTc1MjA0NjgwMCwiZXhwIjpudWxsfQ.dgefoxC7eCOONVUufu2JTVFo9876zQ4Mqwm0ivZ0PQK7Hacm3Ip_xqyav4bilZ0vIEf8IM8AB0_xawk8WcbvMg" [[server_configs]] nickname = "localhost"