Skip to content

Commit 066309a

Browse files
committed
Implement GcDeserialize (and correpsonding derive)
The strategy for passing around GcContext to the derive implementation is to stuff it in a thread lcoal. I call it the "horrible hack" in the code, because it uses a ton of unsafe code to transmute lifetimes. Please use something better (or change serde to support contexts).
1 parent 363275d commit 066309a

File tree

17 files changed

+736
-53
lines changed

17 files changed

+736
-53
lines changed

Cargo.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ scopeguard = "1.1"
1515
# gives zerogc batteries included support.
1616
indexmap = { version = "1.6", optional = true }
1717
parking_lot = { version = "0.11", optional = true }
18+
# Serde support (optional)
19+
serde = { version = "1", optional = true, features = ["derive"] }
1820
# Used for macros
1921
zerogc-derive = { path = "libs/derive", version = "0.2.0-alpha.5" }
2022

@@ -25,10 +27,12 @@ members = ["libs/simple", "libs/derive", "libs/context"]
2527
default = ["std"]
2628
# Depend on the standard library (optional)
2729
#
28-
# This implements tracing
30+
# This implements tracing for most standard library types.
2931
std = ["alloc"]
3032
# Depend on `extern crate alloc` in addition to the Rust `core`
3133
# This is implied by using the standard library (feature="std")
3234
#
3335
# This implements `Trace` for `Box` and collections like `Vec`
3436
alloc = []
37+
# Serde support
38+
serde1 = ["serde", "zerogc-derive/__serde-internal"]

libs/context/src/lib.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,11 @@ unsafe impl<C: RawCollectorImpl> GcContext for CollectorContext<C> {
210210
result
211211
})
212212
}
213+
214+
#[inline]
215+
fn id(&self) -> Self::Id {
216+
unsafe { (&*self.raw).collector() }.id()
217+
}
213218
}
214219

215220
/// It's not safe for a context to be sent across threads.

libs/derive/Cargo.toml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ edition = "2018"
1212
proc-macro = true
1313

1414
[dev-dependencies]
15-
zerogc = { version = "0.2.0-alpha.5", path = "../.." }
15+
zerogc = { version = "0.2.0-alpha.5", path = "../..", features = ["serde1"] }
16+
serde = { version = "1" }
1617

1718
[dependencies]
1819
# Proc macros
@@ -25,3 +26,7 @@ proc-macro-kwargs = "0.1.1"
2526
# Misc
2627
indexmap = "1"
2728
itertools = "0.10.1"
29+
30+
[features]
31+
# Indicates that zerogc was compiled with support for serde,
32+
__serde-internal = []

libs/derive/src/derive.rs

Lines changed: 151 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ use crate::{FromLitStr, MetaList};
1212
#[derive(Copy, Clone, Debug)]
1313
pub enum TraceDeriveKind {
1414
NullTrace,
15-
Regular
15+
Regular,
16+
Deserialize
1617
}
1718

1819
trait PossiblyIgnoredParam {
@@ -134,7 +135,7 @@ pub struct TraceTypeParam {
134135
#[darling(skip)]
135136
ignore: bool,
136137
#[darling(skip)]
137-
collector_id: bool
138+
collector_id: bool,
138139
}
139140
impl TraceTypeParam {
140141
fn normalize(&mut self) -> Result<(), Error> {
@@ -188,6 +189,9 @@ struct TraceField {
188189
/// to be traced.
189190
#[darling(default)]
190191
unsafe_skip_trace: bool,
192+
193+
#[darling(forward_attrs(serde))]
194+
attrs: Vec<syn::Attribute>
191195
}
192196
impl TraceField {
193197
fn expand_trace(&self, idx: usize, access: &FieldAccess, immutable: bool) -> TokenStream {
@@ -211,9 +215,12 @@ impl TraceField {
211215
}
212216

213217
#[derive(Debug, FromVariant)]
218+
#[darling(attributes(zerogc))]
214219
struct TraceVariant {
215220
ident: Ident,
216-
fields: darling::ast::Fields<TraceField>
221+
fields: darling::ast::Fields<TraceField>,
222+
#[darling(forward_attrs(serde))]
223+
attrs: Vec<syn::Attribute>
217224
}
218225
impl TraceVariant {
219226
fn fields(&self) -> impl Iterator<Item=&'_ TraceField> + '_ {
@@ -267,27 +274,26 @@ pub struct TraceDeriveInput {
267274
is_copy: bool,
268275
/// If the type should implement `TraceImmutable` in addition to `Trace
269276
#[darling(default, rename = "immutable")]
270-
wants_immutable_trace: bool
277+
wants_immutable_trace: bool,
278+
#[darling(forward_attrs(serde))]
279+
attrs: Vec<syn::Attribute>
271280
}
272281
impl TraceDeriveInput {
273-
pub fn determine_field_types(&self, include_ignored: bool) -> HashSet<Type> {
282+
fn all_fields(&self) -> Vec<&TraceField> {
274283
match self.data {
275284
Data::Enum(ref variants) => {
276-
variants.iter()
277-
.flat_map(|var| var.fields())
278-
.filter(|f| !f.unsafe_skip_trace || include_ignored)
279-
.map(|fd| &fd.ty)
280-
.cloned()
281-
.collect()
282-
}
283-
Data::Struct(ref s) => {
284-
s.fields.iter()
285-
.filter(|f| !f.unsafe_skip_trace || include_ignored)
286-
.map(|f| &f.ty).cloned()
287-
.collect()
288-
}
285+
variants.iter().flat_map(|var| var.fields()).collect()
286+
},
287+
Data::Struct(ref fields) => fields.iter().collect()
289288
}
290289
}
290+
pub fn determine_field_types(&self, include_ignored: bool) -> HashSet<Type> {
291+
self.all_fields().iter()
292+
.filter(|f| !f.unsafe_skip_trace || include_ignored)
293+
.map(|fd| &fd.ty)
294+
.cloned()
295+
.collect()
296+
}
291297
pub fn normalize(&mut self, kind: TraceDeriveKind) -> Result<(), Error> {
292298
if *self.nop_trace {
293299
crate::emit_warning("#[zerogc(nop_trace)] is deprecated (use #[derive(NullTrace)] instead)", self.nop_trace.span())
@@ -330,6 +336,17 @@ impl TraceDeriveInput {
330336
fn gc_lifetime(&self) -> Option<&'_ Lifetime> {
331337
self.generics.gc_lifetime.as_ref()
332338
}
339+
fn generics_with_gc_lifetime(&self, lt: Lifetime) -> (syn::Lifetime, Generics) {
340+
let mut generics = self.generics.original.clone();
341+
let gc_lifetime: syn::Lifetime = match self.gc_lifetime() {
342+
Some(lt) => lt.clone(),
343+
None => {
344+
generics.params.push(GenericParam::Lifetime(LifetimeDef::new(lt.clone())));
345+
lt
346+
}
347+
};
348+
(gc_lifetime, generics)
349+
}
333350
/// Expand a `GcSafe` for a specific combination of `Id` & 'gc
334351
///
335352
/// Implicitly modifies the specified generics
@@ -351,7 +368,8 @@ impl TraceDeriveInput {
351368
},
352369
TraceDeriveKind::NullTrace => {
353370
quote!(zerogc::NullTrace)
354-
}
371+
},
372+
TraceDeriveKind::Deserialize => unreachable!()
355373
};
356374
for tp in self.generics.regular_type_params() {
357375
let tp = &tp.ident;
@@ -363,12 +381,14 @@ impl TraceDeriveInput {
363381
generics.make_where_clause().predicates.push(
364382
parse_quote!(#tp: #requirement)
365383
)
366-
}
384+
},
385+
TraceDeriveKind::Deserialize => unreachable!()
367386
}
368387
}
369388
let assertion: Ident = match kind {
370389
TraceDeriveKind::NullTrace => parse_quote!(verify_null_trace),
371-
TraceDeriveKind::Regular => parse_quote!(assert_gc_safe)
390+
TraceDeriveKind::Regular => parse_quote!(assert_gc_safe),
391+
TraceDeriveKind::Deserialize => unreachable!()
372392
};
373393
let ty_generics = self.generics.original.split_for_impl().1;
374394
let (impl_generics, _, where_clause) = generics.split_for_impl();
@@ -383,14 +403,7 @@ impl TraceDeriveInput {
383403
})
384404
}
385405
fn expand_gcsafe(&self, kind: TraceDeriveKind) -> Result<TokenStream, Error> {
386-
let mut generics = self.generics.original.clone();
387-
let gc_lifetime: syn::Lifetime = match self.gc_lifetime() {
388-
Some(lt) => lt.clone(),
389-
None => {
390-
generics.params.push(parse_quote!('gc));
391-
parse_quote!('gc)
392-
}
393-
};
406+
let (gc_lifetime, mut generics) = self.generics_with_gc_lifetime(parse_quote!('gc));
394407
match kind {
395408
TraceDeriveKind::NullTrace => {
396409
// Verify we don't have any explicit collector id
@@ -410,9 +423,108 @@ impl TraceDeriveInput {
410423
self.expand_gcsafe_sepcific(kind, initial, id, gc_lt)
411424
}
412425
)
413-
}
426+
},
427+
TraceDeriveKind::Deserialize => unreachable!()
414428
}
415429
}
430+
431+
fn expand_deserialize(&self) -> Result<TokenStream, Error> {
432+
if !crate::DESERIALIZE_ENABLED {
433+
return Err(Error::custom("The `zerogc/serde1` feature is disabled (please enable it)"));
434+
}
435+
let (gc_lifetime, generics) = self.generics_with_gc_lifetime(parse_quote!('gc));
436+
self.expand_for_each_regular_id(
437+
generics, TraceDeriveKind::Deserialize, gc_lifetime,
438+
&mut |kind, initial, id, gc_lt| {
439+
assert!(matches!(kind, TraceDeriveKind::Deserialize));
440+
let id_is_generic = self.generics.original.type_params()
441+
.any(|param| id.is_ident(&param.ident));
442+
let mut generics = initial.unwrap();
443+
generics.params.push(parse_quote!('deserialize));
444+
let requirement = quote!(for<'deser2> zerogc::serde::GcDeserialize::<#gc_lt, 'deser2, #id>);
445+
for target in self.generics.regular_type_params() {
446+
let target = &target.ident;
447+
generics.make_where_clause().predicates.push(parse_quote!(#target: #requirement));
448+
}
449+
let ty_generics = self.generics.original.split_for_impl().1;
450+
let (impl_generics, _, where_clause) = generics.split_for_impl();
451+
let target_type = &self.ident;
452+
let forward_attrs = &self.attrs;
453+
let deserialize_field = |f: &TraceField| {
454+
let named = f.ident.as_ref().map(|name| quote!(#name: ));
455+
let ty = &f.ty;
456+
let forwarded_attrs = &f.attrs;
457+
let bound = format!(
458+
"{}: for<'deserialize> zerogc::serde::GcDeserialize<{}, 'deserialize, {}>", ty.to_token_stream(),
459+
gc_lt.to_token_stream(), id.to_token_stream()
460+
);
461+
quote! {
462+
#(#forwarded_attrs)*
463+
# [serde(deserialize_with = "deserialize_hack", bound(deserialize = #bound))]
464+
#named #ty
465+
}
466+
};
467+
let handle_fields = |fields: &darling::ast::Fields<TraceField>| {
468+
let handled_fields = fields.fields.iter().map(deserialize_field);
469+
match fields.style {
470+
Style::Tuple => {
471+
quote!{ ( #(#handled_fields),* ) }
472+
}
473+
Style::Struct => {
474+
quote!({ #(#handled_fields),* })
475+
}
476+
Style::Unit => quote!()
477+
}
478+
};
479+
let original_generics = &self.generics.original;
480+
let inner = match self.data {
481+
Data::Enum(ref variants) => {
482+
let variants = variants.iter().map(|v| {
483+
let forward_attrs = &v.attrs;
484+
let name = &v.ident;
485+
let inner = handle_fields(&v.fields);
486+
quote! {
487+
#(#forward_attrs)*
488+
#name #inner
489+
}
490+
});
491+
quote!(enum HackRemoteDeserialize #original_generics { #(#variants),* })
492+
}
493+
Data::Struct(ref f) => {
494+
let fields = handle_fields(f);
495+
quote!(struct HackRemoteDeserialize #original_generics # fields)
496+
}
497+
};
498+
let remote_name = target_type.to_token_stream().to_string();
499+
let id_decl = if id_is_generic {
500+
Some(quote!(#id: zerogc::CollectorId,))
501+
} else { None };
502+
Ok(quote! {
503+
impl #impl_generics zerogc::serde::GcDeserialize<#gc_lt, 'deserialize, #id> for #target_type #ty_generics #where_clause {
504+
fn deserialize_gc<D: serde::Deserializer<'deserialize>>(ctx: &#gc_lt <<#id as zerogc::CollectorId>::System as zerogc::GcSystem>::Context, deserializer: D) -> Result<Self, D::Error> {
505+
use serde::Deserializer;
506+
let _guard = unsafe { zerogc::serde::hack::set_context(ctx) };
507+
unsafe {
508+
debug_assert_eq!(_guard.get_unchecked() as *const _, ctx as *const _);
509+
}
510+
/// Hack function to deserialize via `serde::hack`, with the appropriate `Id` type
511+
///
512+
/// Needed because the actual function is unsafe
513+
#[track_caller]
514+
fn deserialize_hack<'gc, 'de, #id_decl D: serde::de::Deserializer<'de>, T: zerogc::serde::GcDeserialize<#gc_lt, 'de, #id>>(deser: D) -> Result<T, D::Error> {
515+
unsafe { zerogc::serde::hack::unchecked_deserialize_hack::<'gc, 'de, D, #id, T>(deser) }
516+
}
517+
# [derive(serde::Deserialize)]
518+
# [serde(remote = #remote_name)]
519+
#(#forward_attrs)*
520+
#inner ;
521+
HackRemoteDeserialize::deserialize(deserializer)
522+
}
523+
}
524+
})
525+
}
526+
)
527+
}
416528
fn expand_for_each_regular_id(
417529
&self, generics: Generics,
418530
kind: TraceDeriveKind,
@@ -669,6 +781,11 @@ impl TraceDeriveInput {
669781
})
670782
}
671783
fn expand_trusted_drop(&self, kind: TraceDeriveKind) -> TokenStream {
784+
let mut generics = self.generics.original.clone();
785+
for param in self.generics.regular_type_params() {
786+
let name = &param.ident;
787+
generics.make_where_clause().predicates.push(parse_quote!(#name: zerogc::TrustedDrop));
788+
}
672789
#[allow(clippy::if_same_then_else)] // Only necessary because of detailed comment
673790
let protective_drop = if self.is_copy {
674791
/*
@@ -703,7 +820,7 @@ impl TraceDeriveInput {
703820
}))
704821
};
705822
let target_type = &self.ident;
706-
let (impl_generics, ty_generics, where_clause) = self.generics.original.split_for_impl();
823+
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
707824
quote! {
708825
#protective_drop
709826
unsafe impl #impl_generics zerogc::TrustedDrop for #target_type #ty_generics #where_clause {}
@@ -832,6 +949,9 @@ impl TraceDeriveInput {
832949
})
833950
}
834951
pub fn expand(&self, kind: TraceDeriveKind) -> Result<TokenStream, Error> {
952+
if matches!(kind, TraceDeriveKind::Deserialize) {
953+
return self.expand_deserialize();
954+
}
835955
let gcsafe = self.expand_gcsafe(kind)?;
836956
let trace_immutable = if self.wants_immutable_trace {
837957
Some(self.expand_trace(kind, true)?)

libs/derive/src/lib.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,22 @@ pub fn derive_trace(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
135135
res
136136
}
137137

138+
pub(crate) const DESERIALIZE_ENABLED: bool = cfg!(feature = "__serde-internal");
139+
140+
#[proc_macro_derive(GcDeserialize, attributes(zerogc))]
141+
pub fn gc_deserialize(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
142+
let input = parse_macro_input!(input as DeriveInput);
143+
let res = From::from(impl_derive_trace(&input, TraceDeriveKind::Deserialize)
144+
.unwrap_or_else(|e| e.write_errors()));
145+
debug_derive(
146+
"derive(GcDeserialize)",
147+
&input.ident.to_string(),
148+
&format_args!("#[derive(GcDeserialize) for {}", input.ident),
149+
&res
150+
);
151+
res
152+
}
153+
138154
fn impl_derive_trace(input: &DeriveInput, kind: TraceDeriveKind) -> Result<TokenStream, darling::Error> {
139155
let mut input = derive::TraceDeriveInput::from_derive_input(input)?;
140156
input.normalize(kind)?;

0 commit comments

Comments
 (0)