Skip to content

Commit 5f76a12

Browse files
committed
Add #[zerogc(serde(...))] opts for GcDeserialize
Allows overriding automatically inferred #[serde()] opts.
1 parent 02acd89 commit 5f76a12

File tree

3 files changed

+168
-35
lines changed

3 files changed

+168
-35
lines changed

libs/derive/src/derive.rs

Lines changed: 131 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use darling::{Error, FromMeta, FromGenerics, FromTypeParam, FromDeriveInput, FromVariant, FromField};
22
use proc_macro2::{Ident, TokenStream, Span};
3-
use syn::{Generics, Type, GenericParam, TypeParam, Lifetime, Path, parse_quote, PathArguments, GenericArgument, TypePath, Meta, LifetimeDef};
3+
use syn::{GenericArgument, GenericParam, Generics, Lifetime, LifetimeDef, LitStr, Meta, Path, PathArguments, Type, TypeParam, TypePath, parse_quote};
44
use darling::util::{SpannedValue};
55
use quote::{quote_spanned, quote, format_ident, ToTokens};
66
use darling::ast::{Style, Data};
@@ -189,7 +189,8 @@ struct TraceField {
189189
/// to be traced.
190190
#[darling(default)]
191191
unsafe_skip_trace: bool,
192-
192+
#[darling(default, rename = "serde")]
193+
serde_opts: Option<SerdeFieldOpts>,
193194
#[darling(forward_attrs(serde))]
194195
attrs: Vec<syn::Attribute>
195196
}
@@ -249,7 +250,64 @@ impl TraceVariant {
249250
}
250251
}
251252

252-
#[derive(FromDeriveInput)]
253+
/// Custom `#[serde(bound(deserialize = ""))]
254+
#[derive(Debug, Clone, FromMeta)]
255+
struct CustomSerdeBounds {
256+
/// The custom deserialize bound
257+
deserialize: LitStr
258+
}
259+
260+
/// Options for `#[zerogc(serde)]` on a type
261+
#[derive(Debug, Clone, Default, FromMeta)]
262+
struct SerdeTypeOpts {
263+
/// Delegate directly to the `Deserialize` implementation,
264+
/// without generating a wrapper.
265+
///
266+
/// Effectively calls `zerogc::derive_delegating_deserialize!`
267+
///
268+
/// Requires `Self: serde::Deserialize`
269+
///
270+
/// If this is present,
271+
/// then all other options are ignored.
272+
#[darling(default)]
273+
delegate: bool,
274+
/// Override the inferred bounds
275+
///
276+
/// Equivalent to `#[serde(bound(....))]`
277+
#[darling(default, rename = "bound")]
278+
custom_bounds: Option<CustomSerdeBounds>,
279+
}
280+
281+
#[derive(Debug, Clone, Default, FromMeta)]
282+
struct SerdeFieldOpts {
283+
/// Delegate to the `serde::Deserialize`
284+
/// implementation instead of using `GcDeserialize`
285+
///
286+
/// If this option is present,
287+
/// then all other options are ignored.
288+
#[darling(default)]
289+
delegate: bool,
290+
/// Override the inferred bounds for the field.
291+
#[darling(default, rename = "bound")]
292+
custom_bounds: Option<CustomSerdeBounds>,
293+
/// Deserialize this field using a custom
294+
/// deserialization function.
295+
///
296+
/// Equivalent to `#[serde(deserialize_with = "...")]`
297+
#[darling(default)]
298+
deserialize_with: Option<LitStr>,
299+
/// Skip deserializing this field.
300+
///
301+
/// Equivalent to `#[serde(skip_deserializing)]`.
302+
///
303+
/// May choose to override the default with a
304+
/// regular `#[serde(default = "...")]`
305+
/// (but not with the #[zerogc(serde(...))])` syntax)
306+
#[darling(default)]
307+
skip_deserializing: bool
308+
}
309+
310+
#[derive(Debug, FromDeriveInput)]
253311
#[darling(attributes(zerogc))]
254312
pub struct TraceDeriveInput {
255313
pub ident: Ident,
@@ -275,6 +333,8 @@ pub struct TraceDeriveInput {
275333
/// If the type should implement `TraceImmutable` in addition to `Trace
276334
#[darling(default, rename = "immutable")]
277335
wants_immutable_trace: bool,
336+
#[darling(default, rename = "serde")]
337+
serde_opts: Option<SerdeTypeOpts>,
278338
#[darling(forward_attrs(serde))]
279339
attrs: Vec<syn::Attribute>
280340
}
@@ -437,14 +497,17 @@ impl TraceDeriveInput {
437497
generics, TraceDeriveKind::Deserialize, gc_lifetime,
438498
&mut |kind, initial, id, gc_lt| {
439499
assert!(matches!(kind, TraceDeriveKind::Deserialize));
500+
let type_opts = self.serde_opts.clone().unwrap_or_default();
440501
let mut generics = initial.unwrap();
441502
let id_is_generic = generics.type_params()
442503
.any(|param| id.is_ident(&param.ident));
443504
generics.params.push(parse_quote!('deserialize));
444505
let requirement = quote!(for<'deser2> zerogc::serde::GcDeserialize::<#gc_lt, 'deser2, #id>);
445-
for target in self.generics.regular_type_params() {
446-
let target = &target.ident;
447-
generics.make_where_clause().predicates.push(parse_quote!(#target: #requirement));
506+
if !type_opts.delegate {
507+
for target in self.generics.regular_type_params() {
508+
let target = &target.ident;
509+
generics.make_where_clause().predicates.push(parse_quote!(#target: #requirement));
510+
}
448511
}
449512
let ty_generics = self.generics.original.split_for_impl().1;
450513
let (impl_generics, _, where_clause) = generics.split_for_impl();
@@ -454,13 +517,32 @@ impl TraceDeriveInput {
454517
let named = f.ident.as_ref().map(|name| quote!(#name: ));
455518
let ty = &f.ty;
456519
let forwarded_attrs = &f.attrs;
457-
let bound = format!(
458-
"{}: for<'deserialize> zerogc::serde::GcDeserialize<{}, 'deserialize, {}>", ty.to_token_stream(),
459-
gc_lt.to_token_stream(), id.to_token_stream()
460-
);
520+
let serde_opts = f.serde_opts.clone().unwrap_or_default();
521+
let serde_attr = if serde_opts.delegate {
522+
quote!()
523+
} else {
524+
let deserialize_with = serde_opts.deserialize_with.as_ref().map_or_else(
525+
|| String::from("deserialize_hack"),
526+
|with| with.value()
527+
);
528+
let custom_bound = if serde_opts.skip_deserializing || serde_opts.deserialize_with.is_some() {
529+
quote!()
530+
} else {
531+
let bound = serde_opts.custom_bounds
532+
.as_ref().map_or_else(
533+
|| format!(
534+
"{}: for<'deserialize> zerogc::serde::GcDeserialize<{}, 'deserialize, {}>", ty.to_token_stream(),
535+
gc_lt.to_token_stream(), id.to_token_stream()
536+
),
537+
|bounds| bounds.deserialize.value()
538+
);
539+
quote!(, bound(deserialize = #bound))
540+
};
541+
quote!(# [serde(deserialize_with = #deserialize_with #custom_bound)])
542+
};
461543
quote! {
462544
#(#forwarded_attrs)*
463-
# [serde(deserialize_with = "deserialize_hack", bound(deserialize = #bound))]
545+
#serde_attr
464546
#named #ty
465547
}
466548
};
@@ -499,29 +581,46 @@ impl TraceDeriveInput {
499581
let id_decl = if id_is_generic {
500582
Some(quote!(#id: zerogc::CollectorId,))
501583
} else { None };
502-
Ok(quote! {
503-
impl #impl_generics zerogc::serde::GcDeserialize<#gc_lt, 'deserialize, #id> for #target_type #ty_generics #where_clause {
504-
fn deserialize_gc<D: serde::Deserializer<'deserialize>>(ctx: &#gc_lt <<#id as zerogc::CollectorId>::System as zerogc::GcSystem>::Context, deserializer: D) -> Result<Self, D::Error> {
505-
use serde::Deserializer;
506-
let _guard = unsafe { zerogc::serde::hack::set_context(ctx) };
507-
unsafe {
508-
debug_assert_eq!(_guard.get_unchecked() as *const _, ctx as *const _);
509-
}
510-
/// Hack function to deserialize via `serde::hack`, with the appropriate `Id` type
511-
///
512-
/// Needed because the actual function is unsafe
513-
#[track_caller]
514-
fn deserialize_hack<'gc, 'de, #id_decl D: serde::de::Deserializer<'de>, T: zerogc::serde::GcDeserialize<#gc_lt, 'de, #id>>(deser: D) -> Result<T, D::Error> {
515-
unsafe { zerogc::serde::hack::unchecked_deserialize_hack::<'gc, 'de, D, #id, T>(deser) }
584+
if !type_opts.delegate && !original_generics.lifetimes().any(|lt| lt.lifetime == *gc_lt) {
585+
return Err(Error::custom("No 'gc lifetime found during #[derive(GcDeserialize)]. Consider #[zerogc(serde(delegate))] or a PhantomData."))
586+
}
587+
if type_opts.delegate {
588+
Ok(quote! {
589+
impl #impl_generics zerogc::serde::GcDeserialize<#gc_lt, 'deserialize, #id> for #target_type #ty_generics #where_clause {
590+
fn deserialize_gc<D: serde::Deserializer<'deserialize>>(_ctx: &#gc_lt <<#id as zerogc::CollectorId>::System as zerogc::GcSystem>::Context, deserializer: D) -> Result<Self, D::Error> {
591+
<Self as serde::Deserialize<'deserialize>>::deserialize(deserializer)
592+
}
593+
}
594+
})
595+
} else {
596+
let custom_bound = type_opts.custom_bounds.as_ref().map(|bounds| {
597+
let de_bounds = bounds.deserialize.value();
598+
quote!(, bound(deserialize = #de_bounds))
599+
});
600+
Ok(quote! {
601+
impl #impl_generics zerogc::serde::GcDeserialize<#gc_lt, 'deserialize, #id> for #target_type #ty_generics #where_clause {
602+
fn deserialize_gc<D: serde::Deserializer<'deserialize>>(ctx: &#gc_lt <<#id as zerogc::CollectorId>::System as zerogc::GcSystem>::Context, deserializer: D) -> Result<Self, D::Error> {
603+
use serde::Deserializer;
604+
let _guard = unsafe { zerogc::serde::hack::set_context(ctx) };
605+
unsafe {
606+
debug_assert_eq!(_guard.get_unchecked() as *const _, ctx as *const _);
607+
}
608+
/// Hack function to deserialize via `serde::hack`, with the appropriate `Id` type
609+
///
610+
/// Needed because the actual function is unsafe
611+
#[track_caller]
612+
fn deserialize_hack<'gc, 'de, #id_decl D: serde::de::Deserializer<'de>, T: zerogc::serde::GcDeserialize<#gc_lt, 'de, #id>>(deser: D) -> Result<T, D::Error> {
613+
unsafe { zerogc::serde::hack::unchecked_deserialize_hack::<'gc, 'de, D, #id, T>(deser) }
614+
}
615+
# [derive(serde::Deserialize)]
616+
# [serde(remote = #remote_name #custom_bound)]
617+
#(#forward_attrs)*
618+
#inner ;
619+
HackRemoteDeserialize::deserialize(deserializer)
516620
}
517-
# [derive(serde::Deserialize)]
518-
# [serde(remote = #remote_name)]
519-
#(#forward_attrs)*
520-
#inner ;
521-
HackRemoteDeserialize::deserialize(deserializer)
522621
}
523-
}
524-
})
622+
})
623+
}
525624
}
526625
)
527626
}

libs/derive/tests/deserialize.rs

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1+
use std::marker::PhantomData;
2+
13
use zerogc_derive::{GcDeserialize, NullTrace, Trace};
24

35
use zerogc::SimpleAllocCollectorId;
46
use zerogc::prelude::*;
57
use zerogc::epsilon::{EpsilonCollectorId};
8+
use serde::Deserialize;
69

710
#[derive(Trace, GcDeserialize)]
811
#[zerogc(collector_ids(EpsilonCollectorId))]
@@ -16,7 +19,39 @@ struct DeserializeParameterized<'gc, T: GcSafe<'gc, Id>, Id: SimpleAllocCollecto
1619
test: Gc<'gc, Vec<T>, Id>
1720
}
1821

19-
#[derive(NullTrace, GcDeserialize)]
20-
enum PlainEnum {
22+
#[derive(NullTrace, GcDeserialize, Deserialize)]
23+
#[zerogc(serde(delegate))]
24+
#[allow(unused)]
25+
struct DelegatingDeserialize {
26+
foo: String,
27+
bar: i32,
28+
doesnt: DoesntImplGcDeserialize,
29+
}
30+
31+
32+
#[derive(Trace, GcDeserialize)]
33+
#[allow(unused)]
34+
#[zerogc(collector_ids(Id))]
35+
struct DeserializeWith<'gc, Id: CollectorId> {
36+
foo: String,
37+
#[zerogc(serde(delegate))]
38+
doesnt_gc_deser: DoesntImplGcDeserialize,
39+
#[zerogc(serde(deserialize_with = "but_its_a_unit", bound(deserialize = "")))]
40+
doesnt_deser_at_all: DoesntDeserAtAll,
41+
marker: PhantomData<&'gc Id>
42+
}
43+
44+
#[derive(NullTrace, serde::Deserialize)]
45+
#[allow(unused)]
46+
struct DoesntImplGcDeserialize {
47+
foo: String
48+
}
49+
50+
fn but_its_a_unit<'de, D: serde::Deserializer<'de>>(_deser: D) -> Result<DoesntDeserAtAll, D::Error> {
51+
Ok(DoesntDeserAtAll {})
52+
}
53+
54+
#[derive(NullTrace)]
55+
struct DoesntDeserAtAll {
2156

2257
}

src/serde.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
use std::marker::PhantomData;
1313

1414
use crate::array::{GcArray, GcString};
15-
use indexmap::set::Union;
1615
use serde::{Serialize, de::{self, Deserializer, Visitor, DeserializeSeed}, ser::SerializeSeq};
1716

1817
use crate::prelude::*;

0 commit comments

Comments
 (0)