diff --git a/Cargo.lock b/Cargo.lock index c9bd4e05985..cdd72f11363 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5399,6 +5399,7 @@ dependencies = [ "spacetimedb-sats", "strum", "thiserror 1.0.69", + "zstd", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 3159af03756..efd5de964cf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -290,6 +290,7 @@ xdg = "2.5" tikv-jemallocator = { version = "0.6.0", features = ["profiling", "stats"] } tikv-jemalloc-ctl = { version = "0.6.0", features = ["stats"] } jemalloc_pprof = { version = "0.7", features = ["symbolize", "flamegraph"] } +zstd = "0.13" zstd-framed = { version = "0.1.1", features = ["tokio"] } # Vendor the openssl we rely on, rather than depend on a diff --git a/crates/client-api-messages/Cargo.toml b/crates/client-api-messages/Cargo.toml index 90f767ee364..c146d63e4ff 100644 --- a/crates/client-api-messages/Cargo.toml +++ b/crates/client-api-messages/Cargo.toml @@ -13,6 +13,7 @@ spacetimedb-sats = { workspace = true, features = ["bytestring"] } bytes.workspace = true bytestring.workspace = true brotli.workspace = true +zstd.workspace = true chrono = { workspace = true, features = ["serde"] } enum-as-inner.workspace = true flate2.workspace = true diff --git a/crates/client-api-messages/src/websocket.rs b/crates/client-api-messages/src/websocket.rs index 2292f4b12fe..570a5e9abff 100644 --- a/crates/client-api-messages/src/websocket.rs +++ b/crates/client-api-messages/src/websocket.rs @@ -306,12 +306,15 @@ pub struct OneOffQuery { /// The tag recognized by the host and SDKs to mean no compression of a [`ServerMessage`]. pub const SERVER_MSG_COMPRESSION_TAG_NONE: u8 = 0; -/// The tag recognized by the host and SDKs to mean brotli compression of a [`ServerMessage`]. +/// The tag recognized by the host and SDKs to mean brotli compression of a [`ServerMessage`]. pub const SERVER_MSG_COMPRESSION_TAG_BROTLI: u8 = 1; -/// The tag recognized by the host and SDKs to mean brotli compression of a [`ServerMessage`]. +/// The tag recognized by the host and SDKs to mean gzip compression of a [`ServerMessage`]. pub const SERVER_MSG_COMPRESSION_TAG_GZIP: u8 = 2; +/// The tag recognized by the host and SDKs to mean zstd compression of a [`ServerMessage`]. +pub const SERVER_MSG_COMPRESSION_TAG_ZSTD: u8 = 3; + /// Messages sent from the server to the client. #[derive(SpacetimeType, derive_more::From)] #[sats(crate = spacetimedb_lib)] @@ -664,6 +667,7 @@ pub enum CompressableQueryUpdate { Uncompressed(QueryUpdate), Brotli(Bytes), Gzip(Bytes), + Zstd(Bytes), } impl CompressableQueryUpdate { @@ -678,6 +682,10 @@ impl CompressableQueryUpdate { let bytes = gzip_decompress(&bytes).unwrap(); bsatn::from_slice(&bytes).unwrap() } + Self::Zstd(bytes) => { + let bytes = zstd_decompress(&bytes).unwrap(); + bsatn::from_slice(&bytes).unwrap() + } } } } @@ -830,6 +838,12 @@ impl WebsocketFormat for BsatnFormat { gzip_compress(&bytes, &mut out); CompressableQueryUpdate::Gzip(out.into()) } + Compression::Zstd => { + let bytes = bsatn::to_vec(&qu).unwrap(); + let mut out = Vec::new(); + zstd_compress(&bytes, &mut out); + CompressableQueryUpdate::Zstd(out.into()) + } } } } @@ -844,6 +858,8 @@ pub enum Compression { Brotli, /// Compress using gzip if a certain size threshold was met. Gzip, + /// Compress using zstd if a certain size threshold was met. + Zstd, } pub fn decide_compression(len: usize, compression: Compression) -> Compression { @@ -898,6 +914,18 @@ pub fn gzip_decompress(bytes: &[u8]) -> Result, io::Error> { Ok(decompressed) } +pub fn zstd_compress(bytes: &[u8], out: &mut Vec) { + const ZSTD_LEVEL: i32 = 5; + + zstd::stream::copy_encode(bytes, out, ZSTD_LEVEL).expect("Failed to zstd compress `bytes`"); +} + +pub fn zstd_decompress(bytes: &[u8]) -> Result, io::Error> { + let mut decompressed = Vec::new(); + zstd::stream::copy_decode(bytes, &mut decompressed)?; + Ok(decompressed) +} + type RowSize = u16; type RowOffset = u64; diff --git a/crates/core/src/client/messages.rs b/crates/core/src/client/messages.rs index cdd0e3c1c94..b1fb4043917 100644 --- a/crates/core/src/client/messages.rs +++ b/crates/core/src/client/messages.rs @@ -4,10 +4,7 @@ use crate::host::module_host::{EventStatus, ModuleEvent}; use crate::host::ArgsTuple; use crate::messages::websocket as ws; use derive_more::From; -use spacetimedb_client_api_messages::websocket::{ - BsatnFormat, Compression, FormatSwitch, JsonFormat, OneOffTable, RowListLen, WebsocketFormat, - SERVER_MSG_COMPRESSION_TAG_BROTLI, SERVER_MSG_COMPRESSION_TAG_GZIP, SERVER_MSG_COMPRESSION_TAG_NONE, -}; +use spacetimedb_client_api_messages::websocket::{BsatnFormat, Compression, FormatSwitch, JsonFormat, OneOffTable, RowListLen, WebsocketFormat, SERVER_MSG_COMPRESSION_TAG_BROTLI, SERVER_MSG_COMPRESSION_TAG_GZIP, SERVER_MSG_COMPRESSION_TAG_NONE, SERVER_MSG_COMPRESSION_TAG_ZSTD}; use spacetimedb_lib::identity::RequestId; use spacetimedb_lib::ser::serde::SerializeWrapper; use spacetimedb_lib::{ConnectionId, TimeDuration}; @@ -55,6 +52,11 @@ pub fn serialize(msg: impl ToProtocol, config: ws::gzip_compress(srv_msg, &mut out); out } + Compression::Zstd => { + let mut out = vec![SERVER_MSG_COMPRESSION_TAG_ZSTD]; + ws::zstd_compress(srv_msg, &mut out); + out + } }; msg_bytes.into() } diff --git a/crates/sdk/src/websocket.rs b/crates/sdk/src/websocket.rs index 59b5c5e34f5..0d5d908348d 100644 --- a/crates/sdk/src/websocket.rs +++ b/crates/sdk/src/websocket.rs @@ -10,10 +10,7 @@ use bytes::Bytes; use futures::{SinkExt, StreamExt as _, TryStreamExt}; use futures_channel::mpsc; use http::uri::{InvalidUri, Scheme, Uri}; -use spacetimedb_client_api_messages::websocket::{ - brotli_decompress, gzip_decompress, BsatnFormat, Compression, BIN_PROTOCOL, SERVER_MSG_COMPRESSION_TAG_BROTLI, - SERVER_MSG_COMPRESSION_TAG_GZIP, SERVER_MSG_COMPRESSION_TAG_NONE, -}; +use spacetimedb_client_api_messages::websocket::{brotli_decompress, gzip_decompress, zstd_decompress, BsatnFormat, Compression, BIN_PROTOCOL, SERVER_MSG_COMPRESSION_TAG_BROTLI, SERVER_MSG_COMPRESSION_TAG_GZIP, SERVER_MSG_COMPRESSION_TAG_NONE, SERVER_MSG_COMPRESSION_TAG_ZSTD}; use spacetimedb_client_api_messages::websocket::{ClientMessage, ServerMessage}; use spacetimedb_lib::{bsatn, ConnectionId}; use thiserror::Error; @@ -145,6 +142,7 @@ fn make_uri(host: Uri, db_name: &str, connection_id: ConnectionId, params: WsPar // The host uses the same default as the sdk, // but in case this changes, we prefer to be explicit now. Compression::Brotli => path.push_str("&compression=Brotli"), + Compression::Zstd => path.push_str("&compression=Zstd"), }; // Specify the `light` mode if requested. @@ -254,6 +252,13 @@ impl WsConnection { })?) .map_err(|source| WsError::DeserializeMessage { source })? } + SERVER_MSG_COMPRESSION_TAG_ZSTD => { + bsatn::from_slice(&zstd_decompress(bytes).map_err(|source| WsError::Decompress { + scheme: "zstd", + source: Arc::new(source), + })?) + .map_err(|source| WsError::DeserializeMessage { source })? + } c => { return Err(WsError::UnknownCompressionScheme { scheme: c }); }