@@ -44,6 +44,7 @@ use tokio::time::{sleep_until, timeout};
44
44
use tokio_tungstenite:: tungstenite:: Utf8Bytes ;
45
45
46
46
use crate :: auth:: SpacetimeAuth ;
47
+ use crate :: util:: serde:: humantime_duration;
47
48
use crate :: util:: websocket:: {
48
49
CloseCode , CloseFrame , Message as WsMessage , WebSocketConfig , WebSocketStream , WebSocketUpgrade , WsError ,
49
50
} ;
@@ -55,6 +56,16 @@ pub const TEXT_PROTOCOL: HeaderValue = HeaderValue::from_static(ws_api::TEXT_PRO
55
56
#[ allow( clippy:: declare_interior_mutable_const) ]
56
57
pub const BIN_PROTOCOL : HeaderValue = HeaderValue :: from_static ( ws_api:: BIN_PROTOCOL ) ;
57
58
59
+ pub trait HasWebSocketOptions {
60
+ fn websocket_options ( & self ) -> WebSocketOptions ;
61
+ }
62
+
63
+ impl < T : HasWebSocketOptions > HasWebSocketOptions for Arc < T > {
64
+ fn websocket_options ( & self ) -> WebSocketOptions {
65
+ ( * * self ) . websocket_options ( )
66
+ }
67
+ }
68
+
58
69
#[ derive( Deserialize ) ]
59
70
pub struct SubscribeParams {
60
71
pub name_or_identity : NameOrIdentity ,
@@ -88,7 +99,7 @@ pub async fn handle_websocket<S>(
88
99
ws : WebSocketUpgrade ,
89
100
) -> axum:: response:: Result < impl IntoResponse >
90
101
where
91
- S : NodeDelegate + ControlStateDelegate ,
102
+ S : NodeDelegate + ControlStateDelegate + HasWebSocketOptions ,
92
103
{
93
104
if connection_id. is_some ( ) {
94
105
// TODO: Bump this up to `log::warn!` after removing the client SDKs' uses of that parameter.
@@ -146,6 +157,7 @@ where
146
157
. max_message_size ( Some ( 0x2000000 ) )
147
158
. max_frame_size ( None )
148
159
. accept_unmasked_frames ( false ) ;
160
+ let ws_opts = ctx. websocket_options ( ) ;
149
161
150
162
tokio:: spawn ( async move {
151
163
let ws = match ws_upgrade. upgrade ( ws_config) . await {
@@ -163,7 +175,7 @@ where
163
175
None => log:: debug!( "New client connected from unknown ip" ) ,
164
176
}
165
177
166
- let actor = |client, sendrx| ws_client_actor ( client, ws, sendrx) ;
178
+ let actor = |client, sendrx| ws_client_actor ( ws_opts , client, ws, sendrx) ;
167
179
let client = match ClientConnection :: spawn ( client_id, client_config, leader. replica_id , module_rx, actor) . await
168
180
{
169
181
Ok ( s) => s,
@@ -198,13 +210,13 @@ where
198
210
struct ActorState {
199
211
pub client_id : ClientActorId ,
200
212
pub database : Identity ,
201
- config : ActorConfig ,
213
+ config : WebSocketOptions ,
202
214
closed : AtomicBool ,
203
215
got_pong : AtomicBool ,
204
216
}
205
217
206
218
impl ActorState {
207
- pub fn new ( database : Identity , client_id : ClientActorId , config : ActorConfig ) -> Self {
219
+ pub fn new ( database : Identity , client_id : ClientActorId , config : WebSocketOptions ) -> Self {
208
220
Self {
209
221
database,
210
222
client_id,
@@ -235,14 +247,19 @@ impl ActorState {
235
247
}
236
248
}
237
249
238
- struct ActorConfig {
250
+ /// Configuration for WebSocket connections.
251
+ #[ derive( Clone , Copy , Debug , PartialEq , serde:: Serialize , serde:: Deserialize ) ]
252
+ #[ serde( rename_all = "kebab-case" ) ]
253
+ pub struct WebSocketOptions {
239
254
/// Interval at which to send `Ping` frames.
240
255
///
241
256
/// We use pings for connection keep-alive.
242
257
/// Value must be smaller than `idle_timeout`.
243
258
///
244
259
/// Default: 15s
245
- ping_interval : Duration ,
260
+ #[ serde( with = "humantime_duration" ) ]
261
+ #[ serde( default = "WebSocketOptions::default_ping_interval" ) ]
262
+ pub ping_interval : Duration ,
246
263
/// Amount of time after which an idle connection is closed.
247
264
///
248
265
/// A connection is considered idle if no data is received nor sent.
@@ -251,47 +268,80 @@ struct ActorConfig {
251
268
/// Value must be greater than `ping_interval`.
252
269
///
253
270
/// Default: 30s
254
- idle_timeout : Duration ,
271
+ #[ serde( with = "humantime_duration" ) ]
272
+ #[ serde( default = "WebSocketOptions::default_idle_timeout" ) ]
273
+ pub idle_timeout : Duration ,
255
274
/// For how long to keep draining the incoming messages until a client close
256
275
/// is received.
257
276
///
258
277
/// Default: 250ms
259
- close_handshake_timeout : Duration ,
278
+ #[ serde( with = "humantime_duration" ) ]
279
+ #[ serde( default = "WebSocketOptions::default_close_handshake_timeout" ) ]
280
+ pub close_handshake_timeout : Duration ,
260
281
/// Maximum number of messages to queue for processing.
261
282
///
262
283
/// If this number is exceeded, the client is disconnected.
263
284
///
264
285
/// Default: 2048
265
- incoming_queue_length : NonZeroUsize ,
286
+ #[ serde( default = "WebSocketOptions::default_incoming_queue_length" ) ]
287
+ pub incoming_queue_length : NonZeroUsize ,
266
288
}
267
289
268
- impl Default for ActorConfig {
290
+ impl Default for WebSocketOptions {
269
291
fn default ( ) -> Self {
270
- Self {
271
- ping_interval : Duration :: from_secs ( 15 ) ,
272
- idle_timeout : Duration :: from_secs ( 30 ) ,
273
- close_handshake_timeout : Duration :: from_millis ( 250 ) ,
274
- incoming_queue_length :
275
- // SAFETY: 2048 > 0, qed
276
- unsafe { NonZeroUsize :: new_unchecked ( 2048 ) }
277
- }
292
+ Self :: DEFAULT
278
293
}
279
294
}
280
295
281
- async fn ws_client_actor ( client : ClientConnection , ws : WebSocketStream , sendrx : MeteredReceiver < SerializableMessage > ) {
296
+ impl WebSocketOptions {
297
+ const DEFAULT_PING_INTERVAL : Duration = Duration :: from_secs ( 15 ) ;
298
+ const DEFAULT_IDLE_TIMEOUT : Duration = Duration :: from_secs ( 30 ) ;
299
+ const DEFAULT_CLOSE_HANDSHAKE_TIMEOUT : Duration = Duration :: from_millis ( 250 ) ;
300
+ const DEFAULT_INCOMING_QUEUE_LENGTH : NonZeroUsize = NonZeroUsize :: new ( 2048 ) . expect ( "2048 > 0, qed" ) ;
301
+
302
+ const DEFAULT : Self = Self {
303
+ ping_interval : Self :: DEFAULT_PING_INTERVAL ,
304
+ idle_timeout : Self :: DEFAULT_IDLE_TIMEOUT ,
305
+ close_handshake_timeout : Self :: DEFAULT_CLOSE_HANDSHAKE_TIMEOUT ,
306
+ incoming_queue_length : Self :: DEFAULT_INCOMING_QUEUE_LENGTH ,
307
+ } ;
308
+
309
+ const fn default_ping_interval ( ) -> Duration {
310
+ Self :: DEFAULT_PING_INTERVAL
311
+ }
312
+
313
+ const fn default_idle_timeout ( ) -> Duration {
314
+ Self :: DEFAULT_IDLE_TIMEOUT
315
+ }
316
+
317
+ const fn default_close_handshake_timeout ( ) -> Duration {
318
+ Self :: DEFAULT_CLOSE_HANDSHAKE_TIMEOUT
319
+ }
320
+
321
+ const fn default_incoming_queue_length ( ) -> NonZeroUsize {
322
+ Self :: DEFAULT_INCOMING_QUEUE_LENGTH
323
+ }
324
+ }
325
+
326
+ async fn ws_client_actor (
327
+ options : WebSocketOptions ,
328
+ client : ClientConnection ,
329
+ ws : WebSocketStream ,
330
+ sendrx : MeteredReceiver < SerializableMessage > ,
331
+ ) {
282
332
// ensure that even if this task gets cancelled, we always cleanup the connection
283
333
let mut client = scopeguard:: guard ( client, |client| {
284
334
tokio:: spawn ( client. disconnect ( ) ) ;
285
335
} ) ;
286
336
287
- ws_client_actor_inner ( & mut client, < _ > :: default ( ) , ws, sendrx) . await ;
337
+ ws_client_actor_inner ( & mut client, options , ws, sendrx) . await ;
288
338
289
339
ScopeGuard :: into_inner ( client) . disconnect ( ) . await ;
290
340
}
291
341
292
342
async fn ws_client_actor_inner (
293
343
client : & mut ClientConnection ,
294
- config : ActorConfig ,
344
+ config : WebSocketOptions ,
295
345
ws : WebSocketStream ,
296
346
sendrx : MeteredReceiver < SerializableMessage > ,
297
347
) {
@@ -1160,7 +1210,7 @@ mod tests {
1160
1210
dummy_actor_state_with_config ( <_ >:: default ( ) )
1161
1211
}
1162
1212
1163
- fn dummy_actor_state_with_config ( config : ActorConfig ) -> ActorState {
1213
+ fn dummy_actor_state_with_config ( config : WebSocketOptions ) -> ActorState {
1164
1214
ActorState :: new ( Identity :: ZERO , dummy_client_id ( ) , config)
1165
1215
}
1166
1216
@@ -1482,7 +1532,7 @@ mod tests {
1482
1532
1483
1533
#[ tokio:: test]
1484
1534
async fn main_loop_terminates_on_idle_timeout ( ) {
1485
- let state = Arc :: new ( dummy_actor_state_with_config ( ActorConfig {
1535
+ let state = Arc :: new ( dummy_actor_state_with_config ( WebSocketOptions {
1486
1536
idle_timeout : Duration :: from_millis ( 10 ) ,
1487
1537
..<_ >:: default ( )
1488
1538
} ) ) ;
@@ -1520,7 +1570,7 @@ mod tests {
1520
1570
1521
1571
#[ tokio:: test]
1522
1572
async fn main_loop_keepalive_keeps_alive ( ) {
1523
- let state = Arc :: new ( dummy_actor_state_with_config ( ActorConfig {
1573
+ let state = Arc :: new ( dummy_actor_state_with_config ( WebSocketOptions {
1524
1574
ping_interval : Duration :: from_millis ( 5 ) ,
1525
1575
idle_timeout : Duration :: from_millis ( 10 ) ,
1526
1576
..<_ >:: default ( )
@@ -1616,7 +1666,7 @@ mod tests {
1616
1666
1617
1667
#[ tokio:: test]
1618
1668
async fn recv_queue_sends_close_when_at_capacity ( ) {
1619
- let state = Arc :: new ( dummy_actor_state_with_config ( ActorConfig {
1669
+ let state = Arc :: new ( dummy_actor_state_with_config ( WebSocketOptions {
1620
1670
incoming_queue_length : 10 . try_into ( ) . unwrap ( ) ,
1621
1671
..<_ >:: default ( )
1622
1672
} ) ) ;
@@ -1632,7 +1682,7 @@ mod tests {
1632
1682
1633
1683
#[ tokio:: test]
1634
1684
async fn recv_queue_closes_state_if_sender_gone ( ) {
1635
- let state = Arc :: new ( dummy_actor_state_with_config ( ActorConfig {
1685
+ let state = Arc :: new ( dummy_actor_state_with_config ( WebSocketOptions {
1636
1686
incoming_queue_length : 10 . try_into ( ) . unwrap ( ) ,
1637
1687
..<_ >:: default ( )
1638
1688
} ) ) ;
@@ -1695,4 +1745,27 @@ mod tests {
1695
1745
Poll :: Ready ( Ok ( ( ) ) )
1696
1746
}
1697
1747
}
1748
+
1749
+ #[ test]
1750
+ fn options_toml_roundtrip ( ) {
1751
+ let options = WebSocketOptions :: default ( ) ;
1752
+ let toml = toml:: to_string ( & options) . unwrap ( ) ;
1753
+ assert_eq ! ( options, toml:: from_str:: <WebSocketOptions >( & toml) . unwrap( ) ) ;
1754
+ }
1755
+
1756
+ #[ test]
1757
+ fn options_from_partial_toml ( ) {
1758
+ let toml = r#"
1759
+ ping-interval = "53s"
1760
+ idle-timeout = "1m 3s"
1761
+ "# ;
1762
+
1763
+ let expected = WebSocketOptions {
1764
+ ping_interval : Duration :: from_secs ( 53 ) ,
1765
+ idle_timeout : Duration :: from_secs ( 63 ) ,
1766
+ ..<_ >:: default ( )
1767
+ } ;
1768
+
1769
+ assert_eq ! ( expected, toml:: from_str( toml) . unwrap( ) ) ;
1770
+ }
1698
1771
}
0 commit comments