Skip to content

Commit 22fa27e

Browse files
authored
Fix roundtrip deserialization of durations (#233)
1 parent da73a1d commit 22fa27e

File tree

6 files changed

+221
-24
lines changed

6 files changed

+221
-24
lines changed

lib/src/types/duration.rs

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@ use neo4rs_macros::BoltStruct;
44
#[derive(Debug, PartialEq, Eq, Clone, BoltStruct)]
55
#[signature(0xB4, 0x45)]
66
pub struct BoltDuration {
7-
months: BoltInteger,
8-
days: BoltInteger,
9-
seconds: BoltInteger,
10-
nanoseconds: BoltInteger,
7+
pub(crate) months: BoltInteger,
8+
pub(crate) days: BoltInteger,
9+
pub(crate) seconds: BoltInteger,
10+
pub(crate) nanoseconds: BoltInteger,
1111
}
1212

1313
impl BoltDuration {
@@ -31,10 +31,6 @@ impl BoltDuration {
3131
.saturating_add(self.days.value.saturating_mul(24 * 3600))
3232
.saturating_add(self.months.value.saturating_mul(2_629_800))
3333
}
34-
35-
pub(crate) fn nanoseconds(&self) -> i64 {
36-
self.nanoseconds.value
37-
}
3834
}
3935

4036
impl From<std::time::Duration> for BoltDuration {
@@ -53,8 +49,7 @@ impl From<std::time::Duration> for BoltDuration {
5349
impl From<BoltDuration> for std::time::Duration {
5450
fn from(value: BoltDuration) -> Self {
5551
//TODO: clarify month issue
56-
let seconds =
57-
value.seconds.value + (value.days.value * 24 * 3600) + (value.months.value * 2_629_800);
52+
let seconds = value.seconds();
5853
std::time::Duration::new(seconds as u64, value.nanoseconds.value as u32)
5954
}
6055
}

lib/src/types/serde/date_time.rs

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,15 @@ use core::fmt;
22
use std::{iter::Peekable, marker::PhantomData};
33

44
use serde::de::{
5-
value::{BorrowedStrDeserializer, MapDeserializer, SeqDeserializer},
5+
value::{BorrowedStrDeserializer, MapDeserializer},
66
DeserializeSeed, Error, IntoDeserializer, MapAccess, SeqAccess, Visitor,
77
};
88

9-
use crate::types::{serde::builder::SetOnce, BoltLocalDateTime, BoltString};
109
use crate::{
11-
types::{BoltDateTime, BoltDateTimeZoneId, BoltDuration, BoltInteger},
10+
types::{
11+
serde::builder::SetOnce, BoltDateTime, BoltDateTimeZoneId, BoltInteger, BoltLocalDateTime,
12+
BoltString,
13+
},
1214
DeError,
1315
};
1416

@@ -57,12 +59,6 @@ impl BoltDateTimeZoneId {
5759
}
5860
}
5961

60-
impl BoltDuration {
61-
pub(crate) fn seq_access(&self) -> impl SeqAccess<'_, Error = DeError> {
62-
SeqDeserializer::new([self.seconds(), self.nanoseconds()].into_iter())
63-
}
64-
}
65-
6662
struct BoltDateTimeZoneIdAccess<'a, const N: usize>(
6763
&'a BoltDateTimeZoneId,
6864
Peekable<<[Fields; N] as IntoIterator>::IntoIter>,

lib/src/types/serde/duration.rs

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
use core::fmt;
2+
3+
use serde::de::{value::SeqDeserializer, Error, MapAccess, SeqAccess, Visitor};
4+
5+
use crate::{
6+
types::{serde::builder::SetOnce, BoltDuration, BoltInteger},
7+
DeError,
8+
};
9+
10+
crate::cenum!(Fields {
11+
Months,
12+
Days,
13+
Seconds,
14+
NanoSeconds,
15+
});
16+
17+
impl BoltDuration {
18+
pub(crate) fn seq_access_bolt(&self) -> impl SeqAccess<'_, Error = DeError> {
19+
SeqDeserializer::new(
20+
[
21+
self.months.value,
22+
self.days.value,
23+
self.seconds.value,
24+
self.nanoseconds.value,
25+
]
26+
.into_iter(),
27+
)
28+
}
29+
pub(crate) fn seq_access_external(&self) -> impl SeqAccess<'_, Error = DeError> {
30+
SeqDeserializer::new([self.seconds(), self.nanoseconds.value].into_iter())
31+
}
32+
}
33+
34+
pub struct BoltDurationVisitor;
35+
36+
impl<'de> Visitor<'de> for BoltDurationVisitor {
37+
type Value = BoltDuration;
38+
39+
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
40+
formatter.write_str("BoltDuration struct")
41+
}
42+
43+
fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
44+
where
45+
A: MapAccess<'de>,
46+
{
47+
let mut builder = DurationBuilder::default();
48+
49+
while let Some(key) = map.next_key::<Fields>()? {
50+
match key {
51+
Fields::Months => builder.months(|| map.next_value())?,
52+
Fields::Days => builder.days(|| map.next_value())?,
53+
Fields::Seconds => builder.seconds(|| map.next_value())?,
54+
Fields::NanoSeconds => builder.nanoseconds(|| map.next_value())?,
55+
}
56+
}
57+
58+
builder.build()
59+
}
60+
61+
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
62+
where
63+
A: SeqAccess<'de>,
64+
{
65+
const FIELDS: [Fields; 4] = [
66+
Fields::Months,
67+
Fields::Days,
68+
Fields::Seconds,
69+
Fields::NanoSeconds,
70+
];
71+
72+
let mut require_next = |field| {
73+
seq.next_element()
74+
.and_then(|value| value.ok_or_else(|| Error::missing_field(field)))
75+
};
76+
77+
let mut builder = DurationBuilder::default();
78+
79+
for field in FIELDS {
80+
match field {
81+
Fields::Months => builder.months(|| require_next("months"))?,
82+
Fields::Days => builder.days(|| require_next("days"))?,
83+
Fields::Seconds => builder.seconds(|| require_next("seconds"))?,
84+
Fields::NanoSeconds => builder.nanoseconds(|| require_next("nanoseconds"))?,
85+
}
86+
}
87+
88+
if seq.next_element::<serde::de::IgnoredAny>()?.is_some() {
89+
return Err(Error::invalid_length(0, &"4"));
90+
}
91+
92+
builder.build()
93+
}
94+
}
95+
96+
#[derive(Default)]
97+
pub(crate) struct DurationBuilder {
98+
pub(crate) months: SetOnce<BoltInteger>,
99+
pub(crate) days: SetOnce<BoltInteger>,
100+
pub(crate) seconds: SetOnce<BoltInteger>,
101+
pub(crate) nanoseconds: SetOnce<BoltInteger>,
102+
}
103+
104+
impl DurationBuilder {
105+
fn months<E: Error>(&mut self, f: impl FnOnce() -> Result<BoltInteger, E>) -> Result<(), E> {
106+
self.months
107+
.try_insert_with(f)
108+
.map_or_else(|_| Err(Error::duplicate_field("months")), |_| Ok(()))
109+
}
110+
111+
fn days<E: Error>(&mut self, f: impl FnOnce() -> Result<BoltInteger, E>) -> Result<(), E> {
112+
self.days
113+
.try_insert_with(f)
114+
.map_or_else(|_| Err(Error::duplicate_field("days")), |_| Ok(()))
115+
}
116+
117+
fn seconds<E: Error>(&mut self, f: impl FnOnce() -> Result<BoltInteger, E>) -> Result<(), E> {
118+
self.seconds
119+
.try_insert_with(f)
120+
.map_or_else(|_| Err(Error::duplicate_field("seconds")), |_| Ok(()))
121+
}
122+
123+
fn nanoseconds<E: Error>(
124+
&mut self,
125+
f: impl FnOnce() -> Result<BoltInteger, E>,
126+
) -> Result<(), E> {
127+
self.nanoseconds
128+
.try_insert_with(f)
129+
.map_or_else(|_| Err(Error::duplicate_field("nanoseconds")), |_| Ok(()))
130+
}
131+
132+
fn build<E: Error>(mut self: DurationBuilder) -> Result<BoltDuration, E> {
133+
Ok(BoltDuration {
134+
months: self
135+
.months
136+
.take()
137+
.ok_or_else(|| Error::missing_field("months"))?,
138+
days: self
139+
.days
140+
.take()
141+
.ok_or_else(|| Error::missing_field("days"))?,
142+
seconds: self
143+
.seconds
144+
.take()
145+
.ok_or_else(|| Error::missing_field("seconds"))?,
146+
nanoseconds: self
147+
.nanoseconds
148+
.take()
149+
.ok_or_else(|| Error::missing_field("nanoseconds"))?,
150+
})
151+
}
152+
}

lib/src/types/serde/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ mod builder;
1010
mod cenum;
1111
mod date_time;
1212
mod de;
13+
mod duration;
1314
mod element;
1415
mod error;
1516
mod kind;

lib/src/types/serde/typ.rs

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use crate::{
22
types::{
33
serde::{
44
date_time::BoltDateTimeVisitor,
5+
duration::BoltDurationVisitor,
56
element::ElementDataDeserializer,
67
node::BoltNodeVisitor,
78
path::BoltPathVisitor,
@@ -240,7 +241,9 @@ impl<'de> Visitor<'de> for BoltTypeVisitor {
240241
BoltKind::Path => variant
241242
.tuple_variant(1, BoltPathVisitor)
242243
.map(BoltType::Path),
243-
BoltKind::Duration => variant.tuple_variant(1, self),
244+
BoltKind::Duration => variant
245+
.tuple_variant(1, BoltDurationVisitor)
246+
.map(BoltType::Duration),
244247
BoltKind::Date => variant
245248
.tuple_variant(1, BoltDateTimeVisitor::<BoltDate>::new())
246249
.map(BoltType::Date),
@@ -328,7 +331,7 @@ impl<'de> Deserializer<'de> for BoltTypeDeserializer<'de> {
328331
BoltType::Point3D(p) => p
329332
.into_deserializer()
330333
.deserialize_struct(name, fields, visitor),
331-
BoltType::Duration(d) => visitor.visit_seq(d.seq_access()),
334+
BoltType::Duration(d) => visitor.visit_seq(d.seq_access_external()),
332335
_ => self.unexpected(visitor),
333336
}
334337
}
@@ -360,7 +363,7 @@ impl<'de> Deserializer<'de> for BoltTypeDeserializer<'de> {
360363
BoltType::Point3D(p) => p
361364
.into_deserializer()
362365
.deserialize_newtype_struct(name, visitor),
363-
BoltType::Duration(d) => visitor.visit_seq(d.seq_access()),
366+
BoltType::Duration(d) => visitor.visit_seq(d.seq_access_external()),
364367
BoltType::DateTimeZoneId(dtz) if name == "Timezone" => {
365368
visitor.visit_newtype_struct(BorrowedStrDeserializer::new(dtz.tz_id()))
366369
}
@@ -378,7 +381,8 @@ impl<'de> Deserializer<'de> for BoltTypeDeserializer<'de> {
378381
}
379382
BoltType::Point2D(p) => p.into_deserializer().deserialize_tuple(len, visitor),
380383
BoltType::Point3D(p) => p.into_deserializer().deserialize_tuple(len, visitor),
381-
BoltType::Duration(d) if len == 2 => visitor.visit_seq(d.seq_access()),
384+
BoltType::Duration(d) if len == 2 => visitor.visit_seq(d.seq_access_external()),
385+
BoltType::Duration(d) if len == 4 => visitor.visit_seq(d.seq_access_bolt()),
382386
BoltType::DateTimeZoneId(dtz) => visitor.visit_seq(
383387
dtz.seq_access(
384388
std::any::type_name::<V>()
@@ -879,7 +883,8 @@ impl<'de> VariantAccess<'de> for BoltEnum<'de> {
879883
BoltType::Point3D(p) => BoltPointDeserializer::new(p).deserialize_tuple(len, visitor),
880884
BoltType::Bytes(b) => visitor.visit_borrowed_bytes(&b.value),
881885
BoltType::Path(p) => ElementDataDeserializer::new(p).tuple_variant(len, visitor),
882-
BoltType::Duration(d) => visitor.visit_seq(d.seq_access()),
886+
BoltType::Duration(d) if len == 1 => visitor.visit_seq(d.seq_access_bolt()),
887+
BoltType::Duration(d) => visitor.visit_seq(d.seq_access_external()),
883888
BoltType::Date(d) => visitor.visit_map(d.map_access()),
884889
BoltType::Time(t) => visitor.visit_map(t.map_access()),
885890
BoltType::LocalTime(t) => visitor.visit_map(t.map_access()),
@@ -2007,6 +2012,20 @@ mod tests {
20072012
assert_eq!(actual, duration);
20082013
}
20092014

2015+
#[test]
2016+
fn duration_roundtrip() {
2017+
let duration = BoltDuration::from(Duration::new(42, 1337));
2018+
2019+
let bolt = BoltType::Duration(duration.clone());
2020+
2021+
let actual = bolt.to::<BoltType>().unwrap();
2022+
let BoltType::Duration(actual) = actual else {
2023+
panic!()
2024+
};
2025+
2026+
assert_eq!(actual, duration);
2027+
}
2028+
20102029
fn test_date() -> NaiveDate {
20112030
NaiveDate::from_ymd_opt(1999, 7, 14).unwrap()
20122031
}

lib/tests/duration_deserialization.rs

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
use neo4rs::*;
2+
3+
mod container;
4+
5+
#[tokio::test]
6+
async fn duration_deserialization() {
7+
let neo4j = container::Neo4jContainer::new().await;
8+
let graph = neo4j.graph();
9+
10+
let duration = std::time::Duration::new(5259600, 7);
11+
let mut result = graph
12+
.execute(query("RETURN $d as output").param("d", duration))
13+
.await
14+
.unwrap();
15+
let row = result.next().await.unwrap().unwrap();
16+
let d: std::time::Duration = row.get("output").unwrap();
17+
assert_eq!(d, duration);
18+
19+
let mut result = graph
20+
.execute(query("RETURN $d as output").param("d", duration))
21+
.await
22+
.unwrap();
23+
let row = result.next().await.unwrap().unwrap();
24+
let d = row.get::<BoltType>("output").unwrap();
25+
assert_eq!(
26+
d,
27+
BoltType::Duration(BoltDuration::new(
28+
0.into(),
29+
0.into(),
30+
5259600.into(),
31+
7.into(),
32+
))
33+
);
34+
}

0 commit comments

Comments
 (0)