Skip to content

Commit 09717e9

Browse files
authored
Make websocket configurable via config.toml (#2944)
1 parent 88dc369 commit 09717e9

File tree

11 files changed

+235
-51
lines changed

11 files changed

+235
-51
lines changed

Cargo.lock

Lines changed: 3 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/client-api/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,15 @@ jsonwebtoken.workspace = true
5151
scopeguard.workspace = true
5252
serde_with.workspace = true
5353
async-stream.workspace = true
54+
humantime.workspace = true
5455

5556
[target.'cfg(not(target_env = "msvc"))'.dependencies]
5657
jemalloc_pprof.workspace = true
5758

5859
[dev-dependencies]
5960
jsonwebtoken.workspace = true
6061
pretty_assertions = { workspace = true, features = ["unstable"] }
62+
toml.workspace = true
6163

6264
[lints]
6365
workspace = true

crates/client-api/src/routes/database.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ use spacetimedb_lib::db::raw_def::v9::RawModuleDefV9;
3131
use spacetimedb_lib::identity::AuthCtx;
3232
use spacetimedb_lib::{sats, Timestamp};
3333

34-
use super::subscribe::handle_websocket;
34+
use super::subscribe::{handle_websocket, HasWebSocketOptions};
3535

3636
#[derive(Deserialize)]
3737
pub struct CallParams {
@@ -790,7 +790,7 @@ pub struct DatabaseRoutes<S> {
790790

791791
impl<S> Default for DatabaseRoutes<S>
792792
where
793-
S: NodeDelegate + ControlStateDelegate + Clone + 'static,
793+
S: NodeDelegate + ControlStateDelegate + HasWebSocketOptions + Clone + 'static,
794794
{
795795
fn default() -> Self {
796796
use axum::routing::{delete, get, post, put};

crates/client-api/src/routes/subscribe.rs

Lines changed: 99 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ use tokio::time::{sleep_until, timeout};
4444
use tokio_tungstenite::tungstenite::Utf8Bytes;
4545

4646
use crate::auth::SpacetimeAuth;
47+
use crate::util::serde::humantime_duration;
4748
use crate::util::websocket::{
4849
CloseCode, CloseFrame, Message as WsMessage, WebSocketConfig, WebSocketStream, WebSocketUpgrade, WsError,
4950
};
@@ -55,6 +56,16 @@ pub const TEXT_PROTOCOL: HeaderValue = HeaderValue::from_static(ws_api::TEXT_PRO
5556
#[allow(clippy::declare_interior_mutable_const)]
5657
pub const BIN_PROTOCOL: HeaderValue = HeaderValue::from_static(ws_api::BIN_PROTOCOL);
5758

59+
pub trait HasWebSocketOptions {
60+
fn websocket_options(&self) -> WebSocketOptions;
61+
}
62+
63+
impl<T: HasWebSocketOptions> HasWebSocketOptions for Arc<T> {
64+
fn websocket_options(&self) -> WebSocketOptions {
65+
(**self).websocket_options()
66+
}
67+
}
68+
5869
#[derive(Deserialize)]
5970
pub struct SubscribeParams {
6071
pub name_or_identity: NameOrIdentity,
@@ -88,7 +99,7 @@ pub async fn handle_websocket<S>(
8899
ws: WebSocketUpgrade,
89100
) -> axum::response::Result<impl IntoResponse>
90101
where
91-
S: NodeDelegate + ControlStateDelegate,
102+
S: NodeDelegate + ControlStateDelegate + HasWebSocketOptions,
92103
{
93104
if connection_id.is_some() {
94105
// TODO: Bump this up to `log::warn!` after removing the client SDKs' uses of that parameter.
@@ -146,6 +157,7 @@ where
146157
.max_message_size(Some(0x2000000))
147158
.max_frame_size(None)
148159
.accept_unmasked_frames(false);
160+
let ws_opts = ctx.websocket_options();
149161

150162
tokio::spawn(async move {
151163
let ws = match ws_upgrade.upgrade(ws_config).await {
@@ -163,7 +175,7 @@ where
163175
None => log::debug!("New client connected from unknown ip"),
164176
}
165177

166-
let actor = |client, sendrx| ws_client_actor(client, ws, sendrx);
178+
let actor = |client, sendrx| ws_client_actor(ws_opts, client, ws, sendrx);
167179
let client = match ClientConnection::spawn(client_id, client_config, leader.replica_id, module_rx, actor).await
168180
{
169181
Ok(s) => s,
@@ -198,13 +210,13 @@ where
198210
struct ActorState {
199211
pub client_id: ClientActorId,
200212
pub database: Identity,
201-
config: ActorConfig,
213+
config: WebSocketOptions,
202214
closed: AtomicBool,
203215
got_pong: AtomicBool,
204216
}
205217

206218
impl ActorState {
207-
pub fn new(database: Identity, client_id: ClientActorId, config: ActorConfig) -> Self {
219+
pub fn new(database: Identity, client_id: ClientActorId, config: WebSocketOptions) -> Self {
208220
Self {
209221
database,
210222
client_id,
@@ -235,14 +247,19 @@ impl ActorState {
235247
}
236248
}
237249

238-
struct ActorConfig {
250+
/// Configuration for WebSocket connections.
251+
#[derive(Clone, Copy, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
252+
#[serde(rename_all = "kebab-case")]
253+
pub struct WebSocketOptions {
239254
/// Interval at which to send `Ping` frames.
240255
///
241256
/// We use pings for connection keep-alive.
242257
/// Value must be smaller than `idle_timeout`.
243258
///
244259
/// Default: 15s
245-
ping_interval: Duration,
260+
#[serde(with = "humantime_duration")]
261+
#[serde(default = "WebSocketOptions::default_ping_interval")]
262+
pub ping_interval: Duration,
246263
/// Amount of time after which an idle connection is closed.
247264
///
248265
/// A connection is considered idle if no data is received nor sent.
@@ -251,47 +268,80 @@ struct ActorConfig {
251268
/// Value must be greater than `ping_interval`.
252269
///
253270
/// Default: 30s
254-
idle_timeout: Duration,
271+
#[serde(with = "humantime_duration")]
272+
#[serde(default = "WebSocketOptions::default_idle_timeout")]
273+
pub idle_timeout: Duration,
255274
/// For how long to keep draining the incoming messages until a client close
256275
/// is received.
257276
///
258277
/// Default: 250ms
259-
close_handshake_timeout: Duration,
278+
#[serde(with = "humantime_duration")]
279+
#[serde(default = "WebSocketOptions::default_close_handshake_timeout")]
280+
pub close_handshake_timeout: Duration,
260281
/// Maximum number of messages to queue for processing.
261282
///
262283
/// If this number is exceeded, the client is disconnected.
263284
///
264285
/// Default: 2048
265-
incoming_queue_length: NonZeroUsize,
286+
#[serde(default = "WebSocketOptions::default_incoming_queue_length")]
287+
pub incoming_queue_length: NonZeroUsize,
266288
}
267289

268-
impl Default for ActorConfig {
290+
impl Default for WebSocketOptions {
269291
fn default() -> Self {
270-
Self {
271-
ping_interval: Duration::from_secs(15),
272-
idle_timeout: Duration::from_secs(30),
273-
close_handshake_timeout: Duration::from_millis(250),
274-
incoming_queue_length:
275-
// SAFETY: 2048 > 0, qed
276-
unsafe { NonZeroUsize::new_unchecked(2048) }
277-
}
292+
Self::DEFAULT
278293
}
279294
}
280295

281-
async fn ws_client_actor(client: ClientConnection, ws: WebSocketStream, sendrx: MeteredReceiver<SerializableMessage>) {
296+
impl WebSocketOptions {
297+
const DEFAULT_PING_INTERVAL: Duration = Duration::from_secs(15);
298+
const DEFAULT_IDLE_TIMEOUT: Duration = Duration::from_secs(30);
299+
const DEFAULT_CLOSE_HANDSHAKE_TIMEOUT: Duration = Duration::from_millis(250);
300+
const DEFAULT_INCOMING_QUEUE_LENGTH: NonZeroUsize = NonZeroUsize::new(2048).expect("2048 > 0, qed");
301+
302+
const DEFAULT: Self = Self {
303+
ping_interval: Self::DEFAULT_PING_INTERVAL,
304+
idle_timeout: Self::DEFAULT_IDLE_TIMEOUT,
305+
close_handshake_timeout: Self::DEFAULT_CLOSE_HANDSHAKE_TIMEOUT,
306+
incoming_queue_length: Self::DEFAULT_INCOMING_QUEUE_LENGTH,
307+
};
308+
309+
const fn default_ping_interval() -> Duration {
310+
Self::DEFAULT_PING_INTERVAL
311+
}
312+
313+
const fn default_idle_timeout() -> Duration {
314+
Self::DEFAULT_IDLE_TIMEOUT
315+
}
316+
317+
const fn default_close_handshake_timeout() -> Duration {
318+
Self::DEFAULT_CLOSE_HANDSHAKE_TIMEOUT
319+
}
320+
321+
const fn default_incoming_queue_length() -> NonZeroUsize {
322+
Self::DEFAULT_INCOMING_QUEUE_LENGTH
323+
}
324+
}
325+
326+
async fn ws_client_actor(
327+
options: WebSocketOptions,
328+
client: ClientConnection,
329+
ws: WebSocketStream,
330+
sendrx: MeteredReceiver<SerializableMessage>,
331+
) {
282332
// ensure that even if this task gets cancelled, we always cleanup the connection
283333
let mut client = scopeguard::guard(client, |client| {
284334
tokio::spawn(client.disconnect());
285335
});
286336

287-
ws_client_actor_inner(&mut client, <_>::default(), ws, sendrx).await;
337+
ws_client_actor_inner(&mut client, options, ws, sendrx).await;
288338

289339
ScopeGuard::into_inner(client).disconnect().await;
290340
}
291341

292342
async fn ws_client_actor_inner(
293343
client: &mut ClientConnection,
294-
config: ActorConfig,
344+
config: WebSocketOptions,
295345
ws: WebSocketStream,
296346
sendrx: MeteredReceiver<SerializableMessage>,
297347
) {
@@ -1160,7 +1210,7 @@ mod tests {
11601210
dummy_actor_state_with_config(<_>::default())
11611211
}
11621212

1163-
fn dummy_actor_state_with_config(config: ActorConfig) -> ActorState {
1213+
fn dummy_actor_state_with_config(config: WebSocketOptions) -> ActorState {
11641214
ActorState::new(Identity::ZERO, dummy_client_id(), config)
11651215
}
11661216

@@ -1482,7 +1532,7 @@ mod tests {
14821532

14831533
#[tokio::test]
14841534
async fn main_loop_terminates_on_idle_timeout() {
1485-
let state = Arc::new(dummy_actor_state_with_config(ActorConfig {
1535+
let state = Arc::new(dummy_actor_state_with_config(WebSocketOptions {
14861536
idle_timeout: Duration::from_millis(10),
14871537
..<_>::default()
14881538
}));
@@ -1520,7 +1570,7 @@ mod tests {
15201570

15211571
#[tokio::test]
15221572
async fn main_loop_keepalive_keeps_alive() {
1523-
let state = Arc::new(dummy_actor_state_with_config(ActorConfig {
1573+
let state = Arc::new(dummy_actor_state_with_config(WebSocketOptions {
15241574
ping_interval: Duration::from_millis(5),
15251575
idle_timeout: Duration::from_millis(10),
15261576
..<_>::default()
@@ -1616,7 +1666,7 @@ mod tests {
16161666

16171667
#[tokio::test]
16181668
async fn recv_queue_sends_close_when_at_capacity() {
1619-
let state = Arc::new(dummy_actor_state_with_config(ActorConfig {
1669+
let state = Arc::new(dummy_actor_state_with_config(WebSocketOptions {
16201670
incoming_queue_length: 10.try_into().unwrap(),
16211671
..<_>::default()
16221672
}));
@@ -1632,7 +1682,7 @@ mod tests {
16321682

16331683
#[tokio::test]
16341684
async fn recv_queue_closes_state_if_sender_gone() {
1635-
let state = Arc::new(dummy_actor_state_with_config(ActorConfig {
1685+
let state = Arc::new(dummy_actor_state_with_config(WebSocketOptions {
16361686
incoming_queue_length: 10.try_into().unwrap(),
16371687
..<_>::default()
16381688
}));
@@ -1695,4 +1745,27 @@ mod tests {
16951745
Poll::Ready(Ok(()))
16961746
}
16971747
}
1748+
1749+
#[test]
1750+
fn options_toml_roundtrip() {
1751+
let options = WebSocketOptions::default();
1752+
let toml = toml::to_string(&options).unwrap();
1753+
assert_eq!(options, toml::from_str::<WebSocketOptions>(&toml).unwrap());
1754+
}
1755+
1756+
#[test]
1757+
fn options_from_partial_toml() {
1758+
let toml = r#"
1759+
ping-interval = "53s"
1760+
idle-timeout = "1m 3s"
1761+
"#;
1762+
1763+
let expected = WebSocketOptions {
1764+
ping_interval: Duration::from_secs(53),
1765+
idle_timeout: Duration::from_secs(63),
1766+
..<_>::default()
1767+
};
1768+
1769+
assert_eq!(expected, toml::from_str(toml).unwrap());
1770+
}
16981771
}

crates/client-api/src/util.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
mod flat_csv;
2+
pub(crate) mod serde;
23
pub mod websocket;
34

45
use core::fmt;
@@ -111,16 +112,16 @@ impl NameOrIdentity {
111112
}
112113
}
113114

114-
impl<'de> serde::Deserialize<'de> for NameOrIdentity {
115+
impl<'de> ::serde::Deserialize<'de> for NameOrIdentity {
115116
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
116117
where
117-
D: serde::Deserializer<'de>,
118+
D: ::serde::Deserializer<'de>,
118119
{
119120
let s = String::deserialize(deserializer)?;
120121
if let Ok(addr) = Identity::from_hex(&s) {
121122
Ok(NameOrIdentity::Identity(IdentityForUrl::from(addr)))
122123
} else {
123-
let name: DatabaseName = s.try_into().map_err(serde::de::Error::custom)?;
124+
let name: DatabaseName = s.try_into().map_err(::serde::de::Error::custom)?;
124125
Ok(NameOrIdentity::Name(name))
125126
}
126127
}

crates/client-api/src/util/serde.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
/// Ser/De of [`std::time::Duration`] via the `humantime` crate.
2+
///
3+
/// Suitable for use with the `#[serde(with)]` annotation.
4+
pub(crate) mod humantime_duration {
5+
use std::time::Duration;
6+
7+
use ::serde::{Deserialize as _, Deserializer, Serialize as _, Serializer};
8+
9+
pub fn serialize<S: Serializer>(duration: &Duration, ser: S) -> Result<S::Ok, S::Error> {
10+
humantime::format_duration(*duration).to_string().serialize(ser)
11+
}
12+
13+
pub fn deserialize<'de, D: Deserializer<'de>>(de: D) -> Result<Duration, D::Error> {
14+
// TODO: `toml` chokes if we try to derserialize to `&str` here.
15+
let s = String::deserialize(de)?;
16+
humantime::parse_duration(&s).map_err(serde::de::Error::custom)
17+
}
18+
}

crates/core/src/config.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ impl CertificateAuthority {
136136
}
137137

138138
#[serde_with::serde_as]
139-
#[derive(serde::Deserialize, Default)]
139+
#[derive(Clone, serde::Deserialize, Default)]
140140
#[serde(rename_all = "kebab-case")]
141141
pub struct LogConfig {
142142
#[serde_as(as = "Option<serde_with::DisplayFromStr>")]

crates/standalone/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ openssl.workspace = true
4242
parse-size.workspace = true
4343
prometheus.workspace = true
4444
scopeguard.workspace = true
45+
serde.workspace = true
4546
serde_json.workspace = true
4647
sled.workspace = true
4748
socket2.workspace = true

0 commit comments

Comments
 (0)