Skip to content

Commit 1318e7e

Browse files
Centril, coolreader18, gefjon
authored
messages::serialize: take/put buffers from/into a SerializeBufferPool (#2823)
Co-authored-by: Noa <coolreader18@gmail.com> Co-authored-by: Phoebe Goldman <phoebe@clockworklabs.io> Co-authored-by: Phoebe Goldman <phoebe@goldman-tribe.org>
1 parent 27af02c commit 1318e7e

File tree

5 files changed

+215
-66
lines changed

5 files changed

+215
-66
lines changed

crates/client-api-messages/src/websocket.rs

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -859,25 +859,19 @@ pub fn decide_compression(len: usize, compression: Compression) -> Compression {
859859
}
860860
}
861861

862-
pub fn brotli_compress(bytes: &[u8], out: &mut Vec<u8>) {
863-
let reader = &mut &bytes[..];
864-
865-
// The default Brotli buffer size.
866-
const BUFFER_SIZE: usize = 4096;
862+
pub fn brotli_compress(bytes: &[u8], out: &mut impl io::Write) {
867863
// We are optimizing for compression speed,
868864
// so we choose the lowest (fastest) level of compression.
869865
// Experiments on internal workloads have shown compression ratios between 7:1 and 10:1
870866
// for large `SubscriptionUpdate` messages at this level.
871-
const COMPRESSION_LEVEL: u32 = 1;
872-
// The default value for an internal compression parameter.
873-
// See `BrotliEncoderParams` for more details.
874-
const LG_WIN: u32 = 22;
867+
const COMPRESSION_LEVEL: i32 = 1;
875868

876-
let mut encoder = brotli::CompressorReader::new(reader, BUFFER_SIZE, COMPRESSION_LEVEL, LG_WIN);
877-
878-
encoder
879-
.read_to_end(out)
880-
.expect("Failed to Brotli compress `SubscriptionUpdateMessage`");
869+
let params = brotli::enc::BrotliEncoderParams {
870+
quality: COMPRESSION_LEVEL,
871+
..<_>::default()
872+
};
873+
let reader = &mut &bytes[..];
874+
brotli::BrotliCompress(reader, out, &params).expect("should be able to BrotliCompress");
881875
}
882876

883877
pub fn brotli_decompress(bytes: &[u8]) -> Result<Vec<u8>, io::Error> {
@@ -886,10 +880,10 @@ pub fn brotli_decompress(bytes: &[u8]) -> Result<Vec<u8>, io::Error> {
886880
Ok(decompressed)
887881
}
888882

889-
pub fn gzip_compress(bytes: &[u8], out: &mut Vec<u8>) {
883+
pub fn gzip_compress(bytes: &[u8], out: &mut impl io::Write) {
890884
let mut encoder = flate2::write::GzEncoder::new(out, flate2::Compression::fast());
891885
encoder.write_all(bytes).unwrap();
892-
encoder.finish().expect("Failed to gzip compress `bytes`");
886+
encoder.finish().expect("should be able to gzip compress `bytes`");
893887
}
894888

895889
pub fn gzip_decompress(bytes: &[u8]) -> Result<Vec<u8>, io::Error> {

crates/client-api/src/routes/subscribe.rs

Lines changed: 68 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,14 @@ use futures::{Future, FutureExt, SinkExt, StreamExt};
1414
use http::{HeaderValue, StatusCode};
1515
use scopeguard::ScopeGuard;
1616
use serde::Deserialize;
17-
use spacetimedb::client::messages::{serialize, IdentityTokenMessage, SerializableMessage};
17+
use spacetimedb::client::messages::{serialize, IdentityTokenMessage, SerializableMessage, SerializeBuffer};
1818
use spacetimedb::client::{ClientActorId, ClientConfig, ClientConnection, DataMessage, MessageHandleError, Protocol};
19+
use spacetimedb::execution_context::WorkloadType;
1920
use spacetimedb::host::module_host::ClientConnectedError;
2021
use spacetimedb::host::NoSuchModule;
2122
use spacetimedb::util::also_poll;
2223
use spacetimedb::worker_metrics::WORKER_METRICS;
24+
use spacetimedb::Identity;
2325
use spacetimedb_client_api_messages::websocket::{self as ws_api, Compression};
2426
use spacetimedb_lib::connection_id::{ConnectionId, ConnectionIdForUrl};
2527
use std::time::Instant;
@@ -246,6 +248,7 @@ async fn ws_client_actor_inner(
246248
outgoing_queue_length_metric.sub(sendrx.len() as _);
247249
};
248250

251+
let mut msg_buffer = SerializeBuffer::new(client.config);
249252
loop {
250253
rx_buf.clear();
251254
enum Item {
@@ -299,36 +302,40 @@ async fn ws_client_actor_inner(
299302
log::info!("dropping {n} messages due to ws already being closed");
300303
log::debug!("dropped messages: {:?}", &rx_buf[..n]);
301304
} else {
302-
let send_all = async {
303-
for msg in rx_buf.drain(..n) {
304-
let workload = msg.workload();
305-
let num_rows = msg.num_rows();
306-
307-
let msg = datamsg_to_wsmsg(serialize(msg, client.config));
308-
309-
// These metrics should be updated together,
310-
// or not at all.
311-
if let (Some(workload), Some(num_rows)) = (workload, num_rows) {
312-
WORKER_METRICS
313-
.websocket_sent_num_rows
314-
.with_label_values(&addr, &workload)
315-
.observe(num_rows as f64);
316-
WORKER_METRICS
317-
.websocket_sent_msg_size
318-
.with_label_values(&addr, &workload)
319-
.observe(msg.len() as f64);
305+
let send_all = async {
306+
for msg in rx_buf.drain(..n) {
307+
let workload = msg.workload();
308+
let num_rows = msg.num_rows();
309+
310+
// Serialize the message, report metrics,
311+
// and keep a handle to the buffer.
312+
let (msg_alloc, msg_data) = serialize(msg_buffer, msg, client.config);
313+
report_ws_sent_metrics(&addr, workload, num_rows, &msg_data);
314+
315+
// Buffer the message without necessarily sending it.
316+
let res = ws.feed(datamsg_to_wsmsg(msg_data)).await;
317+
318+
// At this point,
319+
// the underlying allocation of `msg_data` should have a single referent
320+
// and this should be `msg_alloc`.
321+
// We can put this back into our pool.
322+
msg_buffer = msg_alloc.try_reclaim()
323+
.expect("should have a unique referent to `msg_alloc`");
324+
325+
if res.is_err() {
326+
return (res, msg_buffer);
327+
}
320328
}
321-
// feed() buffers the message, but does not necessarily send it
322-
ws.feed(msg).await?;
323-
}
324-
// now we flush all the messages to the socket
325-
ws.flush().await
326-
};
329+
// now we flush all the messages to the socket
330+
(ws.flush().await, msg_buffer)
331+
};
327332
// Flush the websocket while continuing to poll the `handle_queue`,
328333
// to avoid deadlocks or delays due to enqueued futures holding resources.
329334
let send_all = also_poll(send_all, make_progress(&mut current_message));
330335
let t1 = Instant::now();
331-
if let Err(error) = send_all.await {
336+
let (send_all_result, buf) = send_all.await;
337+
msg_buffer = buf;
338+
if let Err(error) = send_all_result {
332339
log::warn!("Websocket send error: {error}")
333340
}
334341
let time = t1.elapsed();
@@ -394,10 +401,22 @@ async fn ws_client_actor_inner(
394401
if let Err(e) = res {
395402
if let MessageHandleError::Execution(err) = e {
396403
log::error!("{err:#}");
397-
let msg = serialize(err, client.config);
398-
if let Err(error) = ws.send(datamsg_to_wsmsg(msg)).await {
404+
// Serialize the message and keep a handle to the buffer.
405+
let (msg_alloc, msg_data) = serialize(msg_buffer, err, client.config);
406+
407+
// Send the message; unlike `feed`, `send` also flushes it to the socket.
408+
if let Err(error) = ws.send(datamsg_to_wsmsg(msg_data)).await {
399409
log::warn!("Websocket send error: {error}")
400410
}
411+
412+
// At this point,
413+
// the underlying allocation of `msg_data` should have a single referent
414+
// and this should be `msg_alloc`.
415+
// We can put this back into our pool.
416+
msg_buffer = msg_alloc
417+
.try_reclaim()
418+
.expect("should have a unique referent to `msg_alloc`");
419+
401420
continue;
402421
}
403422
log::debug!("Client caused error on text message: {}", e);
@@ -461,6 +480,27 @@ impl ClientMessage {
461480
}
462481
}
463482

483+
/// Report metrics on sent rows and message sizes to a websocket client.
484+
fn report_ws_sent_metrics(
485+
addr: &Identity,
486+
workload: Option<WorkloadType>,
487+
num_rows: Option<usize>,
488+
msg_ws: &DataMessage,
489+
) {
490+
// These metrics should be updated together,
491+
// or not at all.
492+
if let (Some(workload), Some(num_rows)) = (workload, num_rows) {
493+
WORKER_METRICS
494+
.websocket_sent_num_rows
495+
.with_label_values(addr, &workload)
496+
.observe(num_rows as f64);
497+
WORKER_METRICS
498+
.websocket_sent_msg_size
499+
.with_label_values(addr, &workload)
500+
.observe(msg_ws.len() as f64);
501+
}
502+
}
503+
464504
fn datamsg_to_wsmsg(msg: DataMessage) -> WsMessage {
465505
match msg {
466506
DataMessage::Text(text) => WsMessage::Text(bytestring_to_utf8bytes(text)),

crates/core/src/client/client_connection.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -234,17 +234,27 @@ impl From<Vec<u8>> for DataMessage {
234234
}
235235

236236
impl DataMessage {
237+
/// Returns the number of bytes this message consists of.
237238
pub fn len(&self) -> usize {
238239
match self {
239-
DataMessage::Text(s) => s.len(),
240-
DataMessage::Binary(b) => b.len(),
240+
Self::Text(s) => s.len(),
241+
Self::Binary(b) => b.len(),
241242
}
242243
}
243244

245+
/// Is the message empty?
244246
#[must_use]
245247
pub fn is_empty(&self) -> bool {
246248
self.len() == 0
247249
}
250+
251+
/// Returns a handle to the underlying allocation of the message without consuming it.
252+
pub fn allocation(&self) -> Bytes {
253+
match self {
254+
DataMessage::Text(alloc) => alloc.as_bytes().clone(),
255+
DataMessage::Binary(alloc) => alloc.clone(),
256+
}
257+
}
248258
}
249259

250260
// if a client racks up this many messages in the queue without ACK'ing

crates/core/src/client/messages.rs

Lines changed: 117 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ use crate::execution_context::WorkloadType;
33
use crate::host::module_host::{EventStatus, ModuleEvent};
44
use crate::host::ArgsTuple;
55
use crate::messages::websocket as ws;
6+
use bytes::{BufMut, Bytes, BytesMut};
7+
use bytestring::ByteString;
68
use derive_more::From;
79
use spacetimedb_client_api_messages::websocket::{
810
BsatnFormat, Compression, FormatSwitch, JsonFormat, OneOffTable, RowListLen, WebsocketFormat,
@@ -27,36 +29,131 @@ pub trait ToProtocol {
2729
pub(super) type SwitchedServerMessage = FormatSwitch<ws::ServerMessage<BsatnFormat>, ws::ServerMessage<JsonFormat>>;
2830
pub(super) type SwitchedDbUpdate = FormatSwitch<ws::DatabaseUpdate<BsatnFormat>, ws::DatabaseUpdate<JsonFormat>>;
2931

32+
/// The initial size of a `serialize` buffer.
33+
/// Currently 4k to align with the Linux page size,
34+
/// and this should be more than enough in the common case.
35+
const SERIALIZE_BUFFER_INIT_CAP: usize = 4096;
36+
37+
/// A buffer used by [`serialize`].
38+
pub struct SerializeBuffer {
39+
uncompressed: BytesMut,
40+
compressed: BytesMut,
41+
}
42+
43+
impl SerializeBuffer {
44+
pub fn new(config: ClientConfig) -> Self {
45+
let uncompressed_capacity = SERIALIZE_BUFFER_INIT_CAP;
46+
let compressed_capacity = if config.compression == Compression::None || config.protocol == Protocol::Text {
47+
0
48+
} else {
49+
SERIALIZE_BUFFER_INIT_CAP
50+
};
51+
Self {
52+
uncompressed: BytesMut::with_capacity(uncompressed_capacity),
53+
compressed: BytesMut::with_capacity(compressed_capacity),
54+
}
55+
}
56+
57+
/// Take the uncompressed message as the one to use.
58+
fn uncompressed(self) -> (InUseSerializeBuffer, Bytes) {
59+
let uncompressed = self.uncompressed.freeze();
60+
let in_use = InUseSerializeBuffer::Uncompressed {
61+
uncompressed: uncompressed.clone(),
62+
compressed: self.compressed,
63+
};
64+
(in_use, uncompressed)
65+
}
66+
67+
/// Write uncompressed data with a leading tag.
68+
fn write_with_tag<F>(&mut self, tag: u8, write: F) -> &[u8]
69+
where
70+
F: FnOnce(bytes::buf::Writer<&mut BytesMut>),
71+
{
72+
self.uncompressed.put_u8(tag);
73+
write((&mut self.uncompressed).writer());
74+
&self.uncompressed[1..]
75+
}
76+
77+
/// Compress the data from a `write_with_tag` call, and change the tag.
78+
fn compress_with_tag(
79+
self,
80+
tag: u8,
81+
write: impl FnOnce(&[u8], &mut bytes::buf::Writer<BytesMut>),
82+
) -> (InUseSerializeBuffer, Bytes) {
83+
let mut writer = self.compressed.writer();
84+
writer.get_mut().put_u8(tag);
85+
write(&self.uncompressed[1..], &mut writer);
86+
let compressed = writer.into_inner().freeze();
87+
let in_use = InUseSerializeBuffer::Compressed {
88+
uncompressed: self.uncompressed,
89+
compressed: compressed.clone(),
90+
};
91+
(in_use, compressed)
92+
}
93+
}
94+
95+
type BytesMutWriter<'a> = bytes::buf::Writer<&'a mut BytesMut>;
96+
97+
pub enum InUseSerializeBuffer {
98+
Uncompressed { uncompressed: Bytes, compressed: BytesMut },
99+
Compressed { uncompressed: BytesMut, compressed: Bytes },
100+
}
101+
102+
impl InUseSerializeBuffer {
103+
pub fn try_reclaim(self) -> Option<SerializeBuffer> {
104+
let (mut uncompressed, mut compressed) = match self {
105+
Self::Uncompressed {
106+
uncompressed,
107+
compressed,
108+
} => (uncompressed.try_into_mut().ok()?, compressed),
109+
Self::Compressed {
110+
uncompressed,
111+
compressed,
112+
} => (uncompressed, compressed.try_into_mut().ok()?),
113+
};
114+
uncompressed.clear();
115+
compressed.clear();
116+
Some(SerializeBuffer {
117+
uncompressed,
118+
compressed,
119+
})
120+
}
121+
}
122+
30123
/// Serialize `msg` into a [`DataMessage`] containing a [`ws::ServerMessage`].
31124
///
32125
/// If `config.protocol` is [`Protocol::Binary`],
33126
/// the message will be conditionally compressed by this method according to `config.compression`.
34-
pub fn serialize(msg: impl ToProtocol<Encoded = SwitchedServerMessage>, config: ClientConfig) -> DataMessage {
35-
// TODO(centril, perf): here we are allocating buffers only to throw them away eventually.
36-
// Consider pooling these allocations so that we reuse them.
127+
pub fn serialize(
128+
mut buffer: SerializeBuffer,
129+
msg: impl ToProtocol<Encoded = SwitchedServerMessage>,
130+
config: ClientConfig,
131+
) -> (InUseSerializeBuffer, DataMessage) {
37132
match msg.to_protocol(config.protocol) {
38-
FormatSwitch::Json(msg) => serde_json::to_string(&SerializeWrapper::new(msg)).unwrap().into(),
133+
FormatSwitch::Json(msg) => {
134+
let out: BytesMutWriter<'_> = (&mut buffer.uncompressed).writer();
135+
serde_json::to_writer(out, &SerializeWrapper::new(msg))
136+
.expect("should be able to json encode a `ServerMessage`");
137+
138+
let (in_use, out) = buffer.uncompressed();
139+
// SAFETY: `serde_json::to_writer` states that:
140+
// > "Serialization guarantees it only feeds valid UTF-8 sequences to the writer."
141+
let msg_json = unsafe { ByteString::from_bytes_unchecked(out) };
142+
(in_use, msg_json.into())
143+
}
39144
FormatSwitch::Bsatn(msg) => {
40145
// First write the tag so that we avoid shifting the entire message at the end.
41-
let mut msg_bytes = vec![SERVER_MSG_COMPRESSION_TAG_NONE];
42-
bsatn::to_writer(&mut msg_bytes, &msg).unwrap();
146+
let srv_msg = buffer.write_with_tag(SERVER_MSG_COMPRESSION_TAG_NONE, |w| {
147+
bsatn::to_writer(w.into_inner(), &msg).unwrap()
148+
});
43149

44150
// Conditionally compress the message.
45-
let srv_msg = &msg_bytes[1..];
46-
let msg_bytes = match ws::decide_compression(srv_msg.len(), config.compression) {
47-
Compression::None => msg_bytes,
48-
Compression::Brotli => {
49-
let mut out = vec![SERVER_MSG_COMPRESSION_TAG_BROTLI];
50-
ws::brotli_compress(srv_msg, &mut out);
51-
out
52-
}
53-
Compression::Gzip => {
54-
let mut out = vec![SERVER_MSG_COMPRESSION_TAG_GZIP];
55-
ws::gzip_compress(srv_msg, &mut out);
56-
out
57-
}
151+
let (in_use, msg_bytes) = match ws::decide_compression(srv_msg.len(), config.compression) {
152+
Compression::None => buffer.uncompressed(),
153+
Compression::Brotli => buffer.compress_with_tag(SERVER_MSG_COMPRESSION_TAG_BROTLI, ws::brotli_compress),
154+
Compression::Gzip => buffer.compress_with_tag(SERVER_MSG_COMPRESSION_TAG_GZIP, ws::gzip_compress),
58155
};
59-
msg_bytes.into()
156+
(in_use, msg_bytes.into())
60157
}
61158
}
62159
}

0 commit comments

Comments
 (0)