Skip to content

Commit dd699c4

Browse files
authored
Make our sats<->serde translation compatible with RON (#1738)
1 parent 49712bf commit dd699c4

File tree

6 files changed

+117
-104
lines changed

6 files changed

+117
-104
lines changed

Cargo.lock

Lines changed: 16 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,7 @@ rayon = "1.8"
191191
rayon-core = "1.11.0"
192192
regex = "1"
193193
reqwest = { version = "0.12", features = ["stream", "json"] }
194+
ron = "0.8"
194195
rusqlite = { version = "0.29.0", features = ["bundled", "column_decltype"] }
195196
rust_decimal = { version = "1.29.1", features = ["db-tokio-postgres"] }
196197
rustc-demangle = "0.1.21"

crates/lib/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ rand.workspace = true
4949
bytes.workspace = true
5050
serde_json.workspace = true
5151
insta.workspace = true
52+
ron.workspace = true
5253

5354
# Also as dev-dependencies for use in _this_ crate's tests.
5455
proptest.workspace = true

crates/lib/tests/serde.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,21 @@ fn test_roundtrip() {
4444
assert_eq!(&original, &result);
4545
}
4646

47+
#[test]
48+
fn test_roundtrip_ron() {
49+
let original = Sample {
50+
identity: Identity::__dummy(),
51+
};
52+
53+
let s = value_serialize(&original);
54+
let result: Sample = spacetimedb_sats::de::Deserialize::deserialize(ValueDeserializer::new(s)).unwrap();
55+
assert_eq!(&original, &result);
56+
57+
let s = ron::to_string(&original).unwrap();
58+
let result: Sample = ron::from_str(&s).unwrap();
59+
assert_eq!(&original, &result);
60+
}
61+
4762
#[test]
4863
fn test_json_mappings() {
4964
let schema = tuple([

crates/sats/src/de/serde.rs

Lines changed: 77 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ use serde::de as serde;
66

77
/// Converts any [`serde::Deserializer`] to a SATS [`Deserializer`]
88
/// so that Serde's data formats can be reused.
9+
///
10+
/// In order for successful round-trip deserialization, the `serde::Deserializer`
11+
/// that this type wraps must support `deserialize_any()`.
912
pub struct SerdeDeserializer<D> {
1013
/// A deserialization data format in Serde.
1114
de: D,
@@ -46,19 +49,11 @@ impl<'de, D: serde::Deserializer<'de>> Deserializer<'de> for SerdeDeserializer<D
4649
type Error = SerdeError<D::Error>;
4750

4851
fn deserialize_product<V: super::ProductVisitor<'de>>(self, visitor: V) -> Result<V::Output, Self::Error> {
49-
self.de
50-
.deserialize_struct("", &[], TupleVisitor { visitor })
51-
.map_err(SerdeError)
52+
self.de.deserialize_any(TupleVisitor { visitor }).map_err(SerdeError)
5253
}
5354

5455
fn deserialize_sum<V: super::SumVisitor<'de>>(self, visitor: V) -> Result<V::Output, Self::Error> {
55-
if visitor.is_option() && self.de.is_human_readable() {
56-
self.de.deserialize_any(OptionVisitor { visitor }).map_err(SerdeError)
57-
} else {
58-
self.de
59-
.deserialize_enum("", &[], EnumVisitor { visitor })
60-
.map_err(SerdeError)
61-
}
56+
self.de.deserialize_any(EnumVisitor { visitor }).map_err(SerdeError)
6257
}
6358

6459
fn deserialize_bool(self) -> Result<bool, Self::Error> {
@@ -267,71 +262,6 @@ impl<'de, A: serde::SeqAccess<'de>> super::SeqProductAccess<'de> for SeqTupleAcc
267262
}
268263
}
269264

270-
/// Converts a `SumVisitor` into a `serde::Visitor` for deserializing option.
271-
struct OptionVisitor<V> {
272-
/// The visitor to convert.
273-
visitor: V,
274-
}
275-
276-
impl<'de, V: super::SumVisitor<'de>> serde::Visitor<'de> for OptionVisitor<V> {
277-
type Value = V::Output;
278-
279-
fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
280-
f.write_str("option")
281-
}
282-
283-
fn visit_map<A: serde::MapAccess<'de>>(self, map: A) -> Result<Self::Value, A::Error> {
284-
self.visitor.visit_sum(SomeAccess(map)).map_err(unwrap_error)
285-
}
286-
287-
fn visit_unit<E: serde::Error>(self) -> Result<Self::Value, E> {
288-
self.visitor.visit_sum(NoneAccess(PhantomData)).map_err(unwrap_error)
289-
}
290-
}
291-
292-
/// Deserializes `some` variant of an optional value.
293-
/// Converts Serde's map deserialization to SATS.
294-
struct SomeAccess<A>(A);
295-
296-
impl<'de, A: serde::MapAccess<'de>> super::SumAccess<'de> for SomeAccess<A> {
297-
type Error = SerdeError<A::Error>;
298-
type Variant = Self;
299-
300-
fn variant<V: super::VariantVisitor>(mut self, visitor: V) -> Result<(V::Output, Self::Variant), Self::Error> {
301-
self.0
302-
.next_key_seed(VariantVisitor { visitor })
303-
.and_then(|x| match x {
304-
Some(x) => Ok((x, self)),
305-
None => Err(serde::Error::custom("expected variant name")),
306-
})
307-
.map_err(SerdeError)
308-
}
309-
}
310-
impl<'de, A: serde::MapAccess<'de>> super::VariantAccess<'de> for SomeAccess<A> {
311-
type Error = SerdeError<A::Error>;
312-
313-
fn deserialize_seed<T: super::DeserializeSeed<'de>>(mut self, seed: T) -> Result<T::Output, Self::Error> {
314-
let ret = self.0.next_value_seed(SeedWrapper(seed)).map_err(SerdeError)?;
315-
self.0.next_key_seed(NothingVisitor).map_err(SerdeError)?;
316-
Ok(ret)
317-
}
318-
}
319-
320-
/// Deserializes nothing, producing `!` effectively.
321-
struct NothingVisitor;
322-
impl<'de> serde::DeserializeSeed<'de> for NothingVisitor {
323-
type Value = std::convert::Infallible;
324-
fn deserialize<D: serde::Deserializer<'de>>(self, deserializer: D) -> Result<Self::Value, D::Error> {
325-
deserializer.deserialize_identifier(self)
326-
}
327-
}
328-
impl serde::Visitor<'_> for NothingVisitor {
329-
type Value = std::convert::Infallible;
330-
fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
331-
f.write_str("nothing")
332-
}
333-
}
334-
335265
/// Deserializes `none` variant of an optional value.
336266
struct NoneAccess<E>(PhantomData<E>);
337267
impl<E: super::Error> super::SumAccess<'_> for NoneAccess<E> {
@@ -364,29 +294,32 @@ impl<'de, V: super::SumVisitor<'de>> serde::Visitor<'de> for EnumVisitor<V> {
364294
type Value = V::Output;
365295

366296
fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
367-
f.write_str("enum")
297+
match self.visitor.sum_name() {
298+
Some(name) => write!(f, "sum type {name}"),
299+
None => f.write_str("sum type"),
300+
}
368301
}
369302

370-
fn visit_enum<A: serde::EnumAccess<'de>>(self, access: A) -> Result<Self::Value, A::Error> {
303+
fn visit_map<A>(self, access: A) -> Result<Self::Value, A::Error>
304+
where
305+
A: serde::MapAccess<'de>,
306+
{
371307
self.visitor.visit_sum(EnumAccess { access }).map_err(unwrap_error)
372308
}
373-
}
374309

375-
/// Converts Serde's `EnumAccess` to SATS `SumAccess`.
376-
struct EnumAccess<A> {
377-
/// The Serde `EnumAccess`.
378-
access: A,
379-
}
380-
381-
impl<'de, A: serde::EnumAccess<'de>> super::SumAccess<'de> for EnumAccess<A> {
382-
type Error = SerdeError<A::Error>;
383-
type Variant = VariantAccess<A::Variant>;
310+
fn visit_seq<A>(self, access: A) -> Result<Self::Value, A::Error>
311+
where
312+
A: serde::SeqAccess<'de>,
313+
{
314+
self.visitor.visit_sum(SeqEnumAccess { access }).map_err(unwrap_error)
315+
}
384316

385-
fn variant<V: super::VariantVisitor>(self, visitor: V) -> Result<(V::Output, Self::Variant), Self::Error> {
386-
self.access
387-
.variant_seed(VariantVisitor { visitor })
388-
.map(|(variant, access)| (variant, VariantAccess { access }))
389-
.map_err(SerdeError)
317+
fn visit_unit<E: serde::Error>(self) -> Result<Self::Value, E> {
318+
if self.visitor.is_option() {
319+
self.visitor.visit_sum(NoneAccess(PhantomData)).map_err(unwrap_error)
320+
} else {
321+
Err(E::invalid_type(serde::Unexpected::Unit, &self))
322+
}
390323
}
391324
}
392325

@@ -400,7 +333,7 @@ impl<'de, V: super::VariantVisitor> serde::DeserializeSeed<'de> for VariantVisit
400333
type Value = V::Output;
401334

402335
fn deserialize<D: serde::Deserializer<'de>>(self, deserializer: D) -> Result<Self::Value, D::Error> {
403-
deserializer.deserialize_identifier(self)
336+
deserializer.deserialize_any(self)
404337
}
405338
}
406339

@@ -430,17 +363,62 @@ impl<V: super::VariantVisitor> serde::Visitor<'_> for VariantVisitor<V> {
430363
}
431364
}
432365

433-
/// Deserializes the data of a variant using Serde's `serde::VariantAccess` translating this to SATS.
434-
struct VariantAccess<A> {
435-
// Implements `serde::VariantAccess`.
366+
/// Converts Serde's `EnumAccess` to SATS `SumAccess`.
367+
struct EnumAccess<A> {
368+
/// The Serde `EnumAccess`.
436369
access: A,
437370
}
438371

439-
impl<'de, A: serde::VariantAccess<'de>> super::VariantAccess<'de> for VariantAccess<A> {
372+
impl<'de, A: serde::MapAccess<'de>> super::SumAccess<'de> for EnumAccess<A> {
440373
type Error = SerdeError<A::Error>;
374+
type Variant = Self;
441375

442-
fn deserialize_seed<T: super::DeserializeSeed<'de>>(self, seed: T) -> Result<T::Output, Self::Error> {
443-
self.access.newtype_variant_seed(SeedWrapper(seed)).map_err(SerdeError)
376+
fn variant<V: super::VariantVisitor>(mut self, visitor: V) -> Result<(V::Output, Self::Variant), Self::Error> {
377+
let errmsg = "expected map representing sum type to have exactly one field";
378+
let key = self
379+
.access
380+
.next_key_seed(VariantVisitor { visitor })
381+
.map_err(SerdeError)?
382+
.ok_or_else(|| SerdeError(serde::Error::custom(errmsg)))?;
383+
Ok((key, self))
384+
}
385+
}
386+
387+
impl<'de, A: serde::MapAccess<'de>> super::VariantAccess<'de> for EnumAccess<A> {
388+
type Error = SerdeError<A::Error>;
389+
390+
fn deserialize_seed<T: super::DeserializeSeed<'de>>(mut self, seed: T) -> Result<T::Output, Self::Error> {
391+
self.access.next_value_seed(SeedWrapper(seed)).map_err(SerdeError)
392+
}
393+
}
394+
395+
struct SeqEnumAccess<A> {
396+
access: A,
397+
}
398+
399+
const SEQ_ENUM_ERR: &str = "expected seq representing sum type to have exactly two fields";
400+
impl<'de, A: serde::SeqAccess<'de>> super::SumAccess<'de> for SeqEnumAccess<A> {
401+
type Error = SerdeError<A::Error>;
402+
type Variant = Self;
403+
404+
fn variant<V: super::VariantVisitor>(mut self, visitor: V) -> Result<(V::Output, Self::Variant), Self::Error> {
405+
let key = self
406+
.access
407+
.next_element_seed(VariantVisitor { visitor })
408+
.map_err(SerdeError)?
409+
.ok_or_else(|| SerdeError(serde::Error::custom(SEQ_ENUM_ERR)))?;
410+
Ok((key, self))
411+
}
412+
}
413+
414+
impl<'de, A: serde::SeqAccess<'de>> super::VariantAccess<'de> for SeqEnumAccess<A> {
415+
type Error = SerdeError<A::Error>;
416+
417+
fn deserialize_seed<T: super::DeserializeSeed<'de>>(mut self, seed: T) -> Result<T::Output, Self::Error> {
418+
self.access
419+
.next_element_seed(SeedWrapper(seed))
420+
.map_err(SerdeError)?
421+
.ok_or_else(|| SerdeError(serde::Error::custom(SEQ_ENUM_ERR)))
444422
}
445423
}
446424

crates/sats/src/ser/serde.rs

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -118,16 +118,18 @@ impl<S: serde::Serializer> Serializer for SerdeSerializer<S> {
118118
value: &T,
119119
) -> Result<Self::Ok, Self::Error> {
120120
// can't use serialize_variant cause we're too dynamic :(
121-
use serde::SerializeMap;
122-
let mut map = self.ser.serialize_map(Some(1)).map_err(SerdeError)?;
121+
use serde::{SerializeMap, SerializeTuple};
123122
let value = SerializeWrapper::from_ref(value);
124123
if let Some(name) = name {
124+
let mut map = self.ser.serialize_map(Some(1)).map_err(SerdeError)?;
125125
map.serialize_entry(name, value).map_err(SerdeError)?;
126+
map.end().map_err(SerdeError)
126127
} else {
127-
// FIXME: this probably wouldn't decode if you ran it back through
128-
map.serialize_entry(&tag, value).map_err(SerdeError)?;
128+
let mut seq = self.ser.serialize_tuple(2).map_err(SerdeError)?;
129+
seq.serialize_element(&tag).map_err(SerdeError)?;
130+
seq.serialize_element(value).map_err(SerdeError)?;
131+
seq.end().map_err(SerdeError)
129132
}
130-
map.end().map_err(SerdeError)
131133
}
132134

133135
unsafe fn serialize_bsatn(self, ty: &crate::AlgebraicType, bsatn: &[u8]) -> Result<Self::Ok, Self::Error> {

0 commit comments

Comments
 (0)