Skip to content

Commit 663a733

Browse files
committed
Fix propagate serde attributes
1 parent f80db01 commit 663a733

File tree

2 files changed

+106
-21
lines changed

2 files changed

+106
-21
lines changed

src/lib.rs

Lines changed: 87 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
//! pub struct Foo {
2020
//! foo: bool,
2121
//! #[flat_path(path=["a", "b", "c"])]
22+
//! #[serde(skip_serializing_if="Option::is_none")]
2223
//! x: Option<u64>,
2324
//! #[serde(rename="INDEX")]
2425
//! index_number: u32,
@@ -34,7 +35,7 @@ use quote::{format_ident, quote, ToTokens};
3435
use syn::punctuated::Punctuated;
3536
use syn::spanned::Spanned;
3637
use syn::{
37-
parse_quote, Attribute, Error, Field, Fields, ItemEnum, ItemStruct, LitStr, Path, Token,
38+
parse_quote, Attribute, Error, Field, Fields, ItemEnum, ItemStruct, LitStr, Path, Token, Type,
3839
};
3940

4041
mod attr;
@@ -181,6 +182,7 @@ fn perform_simple_flat_path_addition(
181182

182183
paths.push(FlatField {
183184
ident: field_name,
185+
ty: field.ty.clone(),
184186
flat_path,
185187
serde_attributes,
186188
});
@@ -228,13 +230,96 @@ fn generate_flat_path_module(flat_fields: Vec<FlatField>) -> TokenStream2 {
228230

229231
struct FlatField {
230232
ident: Ident,
233+
ty: Type,
231234
flat_path: Vec<LitStr>,
232235
serde_attributes: Vec<Attribute>,
233236
}
234237

235238
impl FlatField {
236239
fn generate_serialize_with(&self) -> TokenStream2 {
237-
self.with_structural_derive()
240+
if self.serde_attributes.is_empty() {
241+
return self.with_structural_derive();
242+
}
243+
244+
// This is more prone to errors due to generics than with_structural_derive, but it is able
245+
// to handle serde bounds properly so it is preferred for those cases.
246+
self.with_concrete_type_derive()
247+
}
248+
249+
fn with_concrete_type_derive(&self) -> TokenStream2 {
250+
let mut tokens = TokenStream2::new();
251+
252+
let ty_tokens = self.ty.clone().into_token_stream();
253+
let serialize_bound = LitStr::new(
254+
&format!("{}: ::serde::Serialize", &ty_tokens),
255+
Span::call_site(),
256+
);
257+
let deserialize_bound = LitStr::new(
258+
&format!("{}: ::serde::de::DeserializeOwned", &ty_tokens),
259+
Span::call_site(),
260+
);
261+
262+
let path_length = self.flat_path.len();
263+
let placeholders = (0..path_length)
264+
.map(|x| format_ident!("_{}", x))
265+
.collect::<Vec<_>>();
266+
for (index, field_name) in self.flat_path[..path_length - 1].iter().enumerate() {
267+
let ident = &placeholders[index];
268+
let next = &placeholders[index + 1];
269+
270+
tokens.extend(quote! {
271+
#[repr(transparent)]
272+
#[derive(::serde::Serialize, ::serde::Deserialize, Default)]
273+
#[serde(bound(serialize = #serialize_bound, deserialize = #deserialize_bound))]
274+
struct #ident {
275+
#[serde(rename=#field_name)]
276+
_0: #next
277+
}
278+
});
279+
}
280+
281+
let last_ident = &placeholders[path_length - 1];
282+
let last_field_name = &self.flat_path[path_length - 1];
283+
let serde_attributes = &self.serde_attributes;
284+
let field_type = &self.ty;
285+
286+
let chain = std::iter::repeat(format_ident!("_0")).take(path_length);
287+
tokens.extend(quote! {
288+
#[repr(transparent)]
289+
#[derive(::serde::Serialize, ::serde::Deserialize, Default)]
290+
#[serde(bound(serialize = #serialize_bound, deserialize = #deserialize_bound))]
291+
struct #last_ident {
292+
#[serde(rename=#last_field_name)]
293+
#(#serde_attributes)*
294+
_0: #field_type
295+
}
296+
297+
#[inline(always)]
298+
pub fn deserialize<'de, D>(deserializer: D) -> Result<#field_type, D::Error>
299+
where #field_type: ::serde::Deserialize<'de>,
300+
D: ::serde::Deserializer<'de>,
301+
{
302+
match <_0 as ::serde::Deserialize>::deserialize(deserializer) {
303+
Ok(value) => Ok(value #(.#chain)*),
304+
Err(e) => Err(e)
305+
}
306+
}
307+
308+
#[inline(always)]
309+
pub fn serialize<S>(this: &#field_type, serializer: S) -> Result<S::Ok, S::Error>
310+
where #field_type: ::serde::Serialize,
311+
S: ::serde::Serializer
312+
{
313+
// # Safety
314+
// This is safe as all members within the chain use repr(transparent) to a value of
315+
// T. Furthermore, data is not accessed via this reference until it is converted
316+
// back to &T at the end of the chain.
317+
let chain_ref = unsafe { ::std::mem::transmute::<&#field_type, &_0>(this) };
318+
::serde::Serialize::serialize(chain_ref, serializer)
319+
}
320+
});
321+
322+
tokens
238323
}
239324

240325
fn with_structural_derive(&self) -> TokenStream2 {

tests/propogate_serde_attrs.rs

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,29 +6,29 @@ use serde_flat_path::flat_path;
66
#[serde(default)]
77
pub struct Foo {
88
#[flat_path(path=["a", "b", "c"])]
9-
// #[serde(with = "flip_bool")]
9+
#[serde(with = "flip_bool")]
1010
foo: bool,
1111
#[serde(skip_serializing_if = "Option::is_some")]
1212
x: Option<u64>,
1313
}
1414

15-
// mod flip_bool {
16-
// use serde::{Deserialize, Deserializer, Serialize, Serializer};
17-
//
18-
// pub fn deserialize<'de, D>(deserializer: D) -> Result<bool, D::Error>
19-
// where
20-
// D: Deserializer<'de>,
21-
// {
22-
// Ok(!bool::deserialize(deserializer)?)
23-
// }
24-
//
25-
// pub fn serialize<S>(this: &bool, serializer: S) -> Result<S::Ok, S::Error>
26-
// where
27-
// S: Serializer,
28-
// {
29-
// bool::serialize(&!this, serializer)
30-
// }
31-
// }
15+
mod flip_bool {
16+
use serde::{Deserialize, Deserializer, Serialize, Serializer};
17+
18+
pub fn deserialize<'de, D>(deserializer: D) -> Result<bool, D::Error>
19+
where
20+
D: Deserializer<'de>,
21+
{
22+
Ok(!bool::deserialize(deserializer)?)
23+
}
24+
25+
pub fn serialize<S>(this: &bool, serializer: S) -> Result<S::Ok, S::Error>
26+
where
27+
S: Serializer,
28+
{
29+
bool::serialize(&!this, serializer)
30+
}
31+
}
3232

3333
#[test]
3434
fn serialize_deserialize_struct() {
@@ -38,7 +38,7 @@ fn serialize_deserialize_struct() {
3838
};
3939

4040
let json = serde_json::to_string(&foo_initial).unwrap();
41-
assert_eq!(json, r#"{"a":{"b":{"c":false}}}"#);
41+
assert_eq!(json, r#"{"a":{"b":{"c":true}}}"#);
4242

4343
let foo_modified = Foo {
4444
foo: false,

0 commit comments

Comments
 (0)