diff --git a/crates/types/src/net/replicated_loglet.rs b/crates/types/src/net/replicated_loglet.rs index 3006ec336d..28b9bc6db8 100644 --- a/crates/types/src/net/replicated_loglet.rs +++ b/crates/types/src/net/replicated_loglet.rs @@ -20,6 +20,7 @@ use crate::logs::metadata::SegmentIndex; use crate::logs::{LogId, LogletOffset, Record, SequenceNumber, TailState}; use crate::net::define_rpc; use crate::replicated_loglet::ReplicatedLogletId; +use crate::storage::ArcVec; // ----- ReplicatedLoglet Sequencer API ----- define_rpc! { @@ -69,7 +70,7 @@ impl CommonResponseHeader { pub struct Append { #[serde(flatten)] pub header: CommonRequestHeader, - pub payloads: Vec, + pub payloads: ArcVec, } impl Append { diff --git a/crates/types/src/storage.rs b/crates/types/src/storage.rs index f97e5408b2..172a53f61e 100644 --- a/crates/types/src/storage.rs +++ b/crates/types/src/storage.rs @@ -8,13 +8,17 @@ // the Business Source License, use of this software will be governed // by the Apache License, Version 2.0. +use core::fmt; +use std::marker::PhantomData; use std::mem; +use std::ops::Deref; use std::sync::Arc; use bytes::{Buf, BufMut, Bytes, BytesMut}; use downcast_rs::{impl_downcast, DowncastSync}; use serde::de::{DeserializeOwned, Error as DeserializationError}; use serde::ser::Error as SerializationError; +use serde::ser::SerializeSeq; use serde::{Deserialize, Serialize}; use tracing::error; @@ -395,6 +399,126 @@ pub fn decode_from_flexbuffers( } } +/// [`ArcVec`] mainly used by `message` types to improve +/// cloning of messages. +/// +/// It can replace [`Vec`] most of the time in all structures +/// that need to be serialized over the wire. +/// +/// Internally it keeps the data inside an [`Arc<[T]>`] +#[derive(Debug)] +pub struct ArcVec { + inner: Arc<[T]>, +} + +impl Deref for ArcVec { + type Target = [T]; + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +impl serde::Serialize for ArcVec +where + T: serde::Serialize, +{ + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let mut seq = serializer.serialize_seq(Some(self.len()))?; + for elem in self.iter() { + seq.serialize_element(elem)?; + } + + seq.end() + } +} + +impl<'de, T> serde::Deserialize<'de> for ArcVec +where + T: serde::Deserialize<'de>, +{ + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_seq(ArcVecVisitor::default()) + } +} + +struct ArcVecVisitor { + _phantom: PhantomData, +} + +impl Default for ArcVecVisitor { + fn default() -> Self { + Self { + _phantom: PhantomData, + } + } +} + +impl<'de, T> serde::de::Visitor<'de> for ArcVecVisitor +where + T: serde::Deserialize<'de>, +{ + type Value = ArcVec; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!(formatter, "expecting an array") + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: serde::de::SeqAccess<'de>, + { + let mut vec: Vec = Vec::with_capacity(seq.size_hint().unwrap_or_default()); + while let Some(value) = seq.next_element()? { + vec.push(value); + } + + Ok(vec.into()) + } +} + +impl Clone for ArcVec { + fn clone(&self) -> Self { + Self { + inner: Arc::clone(&self.inner), + } + } +} + +impl From> for Arc<[T]> { + fn from(value: ArcVec) -> Self { + value.inner + } +} + +impl From> for Vec +where + T: Clone, +{ + fn from(value: ArcVec) -> Self { + Vec::from_iter(value.iter().cloned()) + } +} + +impl From> for ArcVec { + fn from(value: Vec) -> Self { + Self { + inner: value.into(), + } + } +} + +impl From> for ArcVec { + fn from(value: Arc<[T]>) -> Self { + Self { inner: value } + } +} + #[cfg(test)] mod tests { use bytes::Bytes;