Skip to content

Commit 846465c

Browse files
committed
GcDeserialize should infer 'GcSimpleAlloc' bound
This is nessicarry for the `GcDeserialize for Gc` impl.
1 parent 5f76a12 commit 846465c

File tree

2 files changed

+54
-4
lines changed

2 files changed

+54
-4
lines changed

libs/derive/src/derive.rs

Lines changed: 52 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,15 @@ struct SerdeTypeOpts {
276276
/// Equivalent to `#[serde(bound(....))]`
277277
#[darling(default, rename = "bound")]
278278
custom_bounds: Option<CustomSerdeBounds>,
279+
/// Require that Id::System::Context: GcSimpleAlloc
280+
///
281+
/// This is necessary for the standard implementation
282+
/// of `GcDeserialize for Gc` to apply.
283+
///
284+
/// It is automatically inferred if you have any `Gc`, `GcArray`
285+
/// or `GcString` fields (ignoring fully qualified paths).
286+
#[darling(default)]
287+
require_simple_alloc: bool,
279288
}
280289

281290
#[derive(Debug, Clone, Default, FromMeta)]
@@ -493,6 +502,30 @@ impl TraceDeriveInput {
493502
return Err(Error::custom("The `zerogc/serde1` feature is disabled (please enable it)"));
494503
}
495504
let (gc_lifetime, generics) = self.generics_with_gc_lifetime(parse_quote!('gc));
505+
let should_require_simple_alloc = self.all_fields().iter().any(|field| {
506+
let is_gc_allocated = match field.ty {
507+
Type::Path(ref p) if p.path.segments.len() == 1 => {
508+
/*
509+
* If we exactly match 'Gc', 'GcArray' or 'GcString',
510+
* then it can be assumed we are garbage collected
511+
* and that we should require Id::System::Context: GcSimpleAlloc
512+
*/
513+
let name = &p.path.segments.last().unwrap().ident;
514+
name == "Gc" || name == "GcArrray" || name == "GcString"
515+
},
516+
_ => false
517+
};
518+
if let Some(ref custom_opts) = field.serde_opts {
519+
if custom_opts.delegate || custom_opts.skip_deserializing ||
520+
custom_opts.custom_bounds.is_some() || custom_opts.deserialize_with.is_some() {
521+
return false;
522+
}
523+
}
524+
is_gc_allocated
525+
}) || self.serde_opts.as_ref().map_or(false, |opts| opts.require_simple_alloc);
526+
let do_require_simple_alloc = |id: &dyn ToTokens| {
527+
quote!(<<#id as zerogc::CollectorId>::System as zerogc::GcSystem>::Context: zerogc::GcSimpleAlloc)
528+
};
496529
self.expand_for_each_regular_id(
497530
generics, TraceDeriveKind::Deserialize, gc_lifetime,
498531
&mut |kind, initial, id, gc_lt| {
@@ -508,6 +541,11 @@ impl TraceDeriveInput {
508541
let target = &target.ident;
509542
generics.make_where_clause().predicates.push(parse_quote!(#target: #requirement));
510543
}
544+
if should_require_simple_alloc {
545+
generics.make_where_clause().predicates.push(
546+
syn::parse2(do_require_simple_alloc(&id)).unwrap()
547+
);
548+
}
511549
}
512550
let ty_generics = self.generics.original.split_for_impl().1;
513551
let (impl_generics, _, where_clause) = generics.split_for_impl();
@@ -593,10 +631,20 @@ impl TraceDeriveInput {
593631
}
594632
})
595633
} else {
596-
let custom_bound = type_opts.custom_bounds.as_ref().map(|bounds| {
634+
let custom_bound = if let Some(ref bounds) = type_opts.custom_bounds {
597635
let de_bounds = bounds.deserialize.value();
598636
quote!(, bound(deserialize = #de_bounds))
599-
});
637+
} else if should_require_simple_alloc {
638+
let de_bounds = format!("{}", do_require_simple_alloc(&id));
639+
quote!(, bound(deserialize = #de_bounds))
640+
} else {
641+
quote!()
642+
};
643+
let hack_where_bound = if should_require_simple_alloc && id_is_generic {
644+
do_require_simple_alloc(&quote!(Id))
645+
} else {
646+
quote!()
647+
};
600648
Ok(quote! {
601649
impl #impl_generics zerogc::serde::GcDeserialize<#gc_lt, 'deserialize, #id> for #target_type #ty_generics #where_clause {
602650
fn deserialize_gc<D: serde::Deserializer<'deserialize>>(ctx: &#gc_lt <<#id as zerogc::CollectorId>::System as zerogc::GcSystem>::Context, deserializer: D) -> Result<Self, D::Error> {
@@ -609,7 +657,8 @@ impl TraceDeriveInput {
609657
///
610658
/// Needed because the actual function is unsafe
611659
#[track_caller]
612-
fn deserialize_hack<'gc, 'de, #id_decl D: serde::de::Deserializer<'de>, T: zerogc::serde::GcDeserialize<#gc_lt, 'de, #id>>(deser: D) -> Result<T, D::Error> {
660+
fn deserialize_hack<'gc, 'de, #id_decl D: serde::de::Deserializer<'de>, T: zerogc::serde::GcDeserialize<#gc_lt, 'de, #id>>(deser: D) -> Result<T, D::Error>
661+
where #hack_where_bound {
613662
unsafe { zerogc::serde::hack::unchecked_deserialize_hack::<'gc, 'de, D, #id, T>(deser) }
614663
}
615664
# [derive(serde::Deserialize)]

libs/derive/tests/deserialize.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@ struct DeserializeWith<'gc, Id: CollectorId> {
3838
doesnt_gc_deser: DoesntImplGcDeserialize,
3939
#[zerogc(serde(deserialize_with = "but_its_a_unit", bound(deserialize = "")))]
4040
doesnt_deser_at_all: DoesntDeserAtAll,
41-
marker: PhantomData<&'gc Id>
41+
marker: PhantomData<&'gc Id>,
42+
deser: Gc<'gc, String, Id>
4243
}
4344

4445
#[derive(NullTrace, serde::Deserialize)]

0 commit comments

Comments
 (0)