Skip to content

Commit f3b7556

Browse files
committed
[derive] Implement safety checks for GcErase + GCRebrand
Remove those nasty warnings Add a #[zerogc(collector_id(Id))] to mark which specific collector you support rebranding for. If you have a `Gc<'gc, SpecificId>` than It's not safe to implement a general `for<I: CollectorId> GcErase<'a, I>` because `I` might not match `SpecificId`. This was caught with the new verification! Horay for type safety!
1 parent 161f5d0 commit f3b7556

File tree

3 files changed

+158
-44
lines changed

3 files changed

+158
-44
lines changed

libs/derive/src/lib.rs

Lines changed: 135 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#![feature(
2-
proc_macro_diagnostic, // Used for warnings
32
proc_macro_tracked_env, // Used for `DEBUG_DERIVE`
43
)]
54
extern crate proc_macro;
@@ -117,6 +116,7 @@ struct TypeAttrs {
117116
is_copy: bool,
118117
nop_trace: bool,
119118
gc_lifetime: Option<Lifetime>,
119+
collector_id: Option<Ident>,
120120
ignore_params: HashSet<Ident>,
121121
ignored_lifetimes: HashSet<Lifetime>
122122
}
@@ -143,6 +143,7 @@ impl Default for TypeAttrs {
143143
is_copy: false,
144144
nop_trace: false,
145145
gc_lifetime: None,
146+
collector_id: None,
146147
ignore_params: Default::default(),
147148
ignored_lifetimes: Default::default(),
148149
}
@@ -211,6 +212,32 @@ impl Parse for TypeAttrs {
211212
}
212213
};
213214
result.gc_lifetime = Some(lifetime);
215+
} else if meta.path().is_ident("collector_id") {
216+
if result.collector_id.is_some() {
217+
return Err(Error::new(
218+
meta.span(),
219+
"Duplicate flags: #[zerogc(collector_id)]"
220+
))
221+
}
222+
fn get_ident_meta(meta: &NestedMeta) -> Option<&Ident> {
223+
match *meta {
224+
NestedMeta::Meta(Meta::Path(ref p)) => p.get_ident(),
225+
_ => None
226+
}
227+
}
228+
let ident = match meta {
229+
Meta::List(ref l) if l.nested.len() == 1
230+
&& get_ident_meta(&l.nested[0]).is_some() => {
231+
get_ident_meta(&l.nested[0]).unwrap()
232+
}
233+
_ => {
234+
return Err(Error::new(
235+
meta.span(),
236+
"Malformed attribute for #[zerogc(collector_id)]"
237+
))
238+
}
239+
};
240+
result.collector_id = Some(ident.clone());
214241
} else if meta.path().is_ident("ignore_params") {
215242
if !result.ignore_params.is_empty() {
216243
return Err(Error::new(
@@ -250,7 +277,7 @@ impl Parse for TypeAttrs {
250277
}
251278
}
252279
} else if meta.path().is_ident("ignore_lifetimes") {
253-
if !result.ignore_params.is_empty() {
280+
if !result.ignored_lifetimes.is_empty() {
254281
return Err(Error::new(
255282
meta.span(),
256283
"Duplicate flags: #[zerogc(ignore_lifetimes)]"
@@ -505,7 +532,13 @@ fn impl_erase_nop(target: &DeriveInput, info: &GcTypeInfo) -> Result<TokenStream
505532
}
506533
let mut impl_generics = generics.clone();
507534
impl_generics.params.push(GenericParam::Lifetime(parse_quote!('min)));
508-
impl_generics.params.push(GenericParam::Type(parse_quote!(S: ::zerogc::CollectorId)));
535+
let collector_id = match info.config.collector_id {
536+
Some(ref id) => id.clone(),
537+
None => {
538+
impl_generics.params.push(GenericParam::Type(parse_quote!(Id: ::zerogc::CollectorId)));
539+
parse_quote!(Id)
540+
}
541+
};
509542
// Require that `Self: NullTrace`
510543
impl_generics.make_where_clause().predicates.push(WherePredicate::Type(PredicateType {
511544
lifetimes: None,
@@ -515,14 +548,8 @@ fn impl_erase_nop(target: &DeriveInput, info: &GcTypeInfo) -> Result<TokenStream
515548
}));
516549
let (_, ty_generics, _) = generics.split_for_impl();
517550
let (impl_generics, _, where_clause) = impl_generics.split_for_impl();
518-
::proc_macro::Diagnostic::spanned(
519-
::proc_macro::Span::call_site(),
520-
::proc_macro::Level::Note,
521-
// We know this is safe because we know that `Self: NullTrace`
522-
"derive(GcRebrand) is safe for NullTrace, unlike standard implementation"
523-
).emit();
524551
Ok(quote! {
525-
unsafe impl #impl_generics ::zerogc::GcErase<'min, S>
552+
unsafe impl #impl_generics ::zerogc::GcErase<'min, #collector_id>
526553
for #name #ty_generics #where_clause {
527554
// We can pass-through because we are NullTrace
528555
type Erased = Self;
@@ -534,15 +561,19 @@ fn impl_erase(target: &DeriveInput, info: &GcTypeInfo) -> Result<TokenStream, Er
534561
let mut generics: Generics = target.generics.clone();
535562
let mut rewritten_params = Vec::new();
536563
let mut rewritten_restrictions = Vec::new();
564+
let collector_id = match info.config.collector_id {
565+
Some(ref id) => id.clone(),
566+
None => parse_quote!(Id)
567+
};
537568
for param in &mut generics.params {
538569
let rewritten_param: GenericArgument;
539570
match param {
540571
GenericParam::Type(ref mut type_param) => {
541572
let original_bounds = type_param.bounds.iter().cloned().collect::<Vec<_>>();
542-
type_param.bounds.push(parse_quote!(::zerogc::GcErase<'min, S>));
573+
type_param.bounds.push(parse_quote!(::zerogc::GcErase<'min, #collector_id>));
543574
type_param.bounds.push(parse_quote!('min));
544575
let param_name = &type_param.ident;
545-
let rewritten_type: Type = parse_quote!(<#param_name as ::zerogc::GcErase<'min, S>>::Erased);
576+
let rewritten_type: Type = parse_quote!(<#param_name as ::zerogc::GcErase<'min, #collector_id>>::Erased);
546577
rewritten_restrictions.push(WherePredicate::Type(PredicateType {
547578
lifetimes: None,
548579
bounded_ty: rewritten_type.clone(),
@@ -569,21 +600,44 @@ fn impl_erase(target: &DeriveInput, info: &GcTypeInfo) -> Result<TokenStream, Er
569600
}
570601
rewritten_params.push(rewritten_param);
571602
}
603+
let mut field_types = Vec::new();
604+
match target.data {
605+
Data::Struct(ref s) => {
606+
for f in &s.fields {
607+
field_types.push(f.ty.clone());
608+
}
609+
},
610+
Data::Enum(ref e) => {
611+
for variant in &e.variants {
612+
for f in &variant.fields {
613+
field_types.push(f.ty.clone());
614+
}
615+
}
616+
},
617+
Data::Union(_) => {
618+
return Err(Error::new(target.ident.span(), "Unable to derive(GcErase) for unions"))
619+
}
620+
}
572621
let mut impl_generics = generics.clone();
573622
impl_generics.params.push(GenericParam::Lifetime(parse_quote!('min)));
574-
impl_generics.params.push(GenericParam::Type(parse_quote!(S: ::zerogc::CollectorId)));
623+
if info.config.collector_id.is_none() {
624+
impl_generics.params.push(GenericParam::Type(parse_quote!(Id: ::zerogc::CollectorId)));
625+
}
575626
impl_generics.make_where_clause().predicates.extend(rewritten_restrictions);
576627
let (_, ty_generics, _) = generics.split_for_impl();
577628
let (impl_generics, _, where_clause) = impl_generics.split_for_impl();
578-
::proc_macro::Diagnostic::spanned(
579-
::proc_macro::Span::call_site(),
580-
::proc_macro::Level::Warning,
581-
"derive(GcErase) doesn't currently verify the correctness of its fields"
582-
).emit();
629+
let assert_erase = field_types.iter().map(|field_type| {
630+
let span = field_type.span();
631+
quote_spanned!(span => <#field_type as ::zerogc::GcErase<'min, #collector_id>>::assert_erase();)
632+
}).collect::<Vec<_>>();
583633
Ok(quote! {
584-
unsafe impl #impl_generics ::zerogc::GcErase<'min, S>
634+
unsafe impl #impl_generics ::zerogc::GcErase<'min, #collector_id>
585635
for #name #ty_generics #where_clause {
586636
type Erased = #name::<#(#rewritten_params),*>;
637+
638+
fn assert_erase() {
639+
#(#assert_erase)*
640+
}
587641
}
588642
})
589643
}
@@ -620,7 +674,13 @@ fn impl_rebrand_nop(target: &DeriveInput, info: &GcTypeInfo) -> Result<TokenStre
620674
}
621675
let mut impl_generics = generics.clone();
622676
impl_generics.params.push(GenericParam::Lifetime(parse_quote!('new_gc)));
623-
impl_generics.params.push(GenericParam::Type(parse_quote!(S: ::zerogc::CollectorId)));
677+
let collector_id = match info.config.collector_id {
678+
Some(ref id) => id.clone(),
679+
None => {
680+
impl_generics.params.push(GenericParam::Type(parse_quote!(Id: ::zerogc::CollectorId)));
681+
parse_quote!(Id)
682+
}
683+
};
624684
// Require that `Self: NullTrace`
625685
impl_generics.make_where_clause().predicates.push(WherePredicate::Type(PredicateType {
626686
lifetimes: None,
@@ -630,14 +690,8 @@ fn impl_rebrand_nop(target: &DeriveInput, info: &GcTypeInfo) -> Result<TokenStre
630690
}));
631691
let (_, ty_generics, _) = generics.split_for_impl();
632692
let (impl_generics, _, where_clause) = impl_generics.split_for_impl();
633-
::proc_macro::Diagnostic::spanned(
634-
::proc_macro::Span::call_site(),
635-
::proc_macro::Level::Note,
636-
// We know this is safe because we know that `Self: NullTrace`
637-
"derive(GcRebrand) is safe for NullTrace, unlike standard implementation"
638-
).emit();
639693
Ok(quote! {
640-
unsafe impl #impl_generics ::zerogc::GcRebrand<'new_gc, S>
694+
unsafe impl #impl_generics ::zerogc::GcRebrand<'new_gc, #collector_id>
641695
for #name #ty_generics #where_clause {
642696
// We can pass-through because we are NullTrace
643697
type Branded = Self;
@@ -649,14 +703,18 @@ fn impl_rebrand(target: &DeriveInput, info: &GcTypeInfo) -> Result<TokenStream,
649703
let mut generics: Generics = target.generics.clone();
650704
let mut rewritten_params = Vec::new();
651705
let mut rewritten_restrictions = Vec::new();
706+
let collector_id = match info.config.collector_id {
707+
Some(ref id) => id.clone(),
708+
None => parse_quote!(Id)
709+
};
652710
for param in &mut generics.params {
653711
let rewritten_param: GenericArgument;
654712
match param {
655713
GenericParam::Type(ref mut type_param) => {
656714
let original_bounds = type_param.bounds.iter().cloned().collect::<Vec<_>>();
657-
type_param.bounds.push(parse_quote!(::zerogc::GcRebrand<'new_gc, S>));
715+
type_param.bounds.push(parse_quote!(::zerogc::GcRebrand<'new_gc, #collector_id>));
658716
let param_name = &type_param.ident;
659-
let rewritten_type: Type = parse_quote!(<#param_name as ::zerogc::GcRebrand<'new_gc, S>>::Branded);
717+
let rewritten_type: Type = parse_quote!(<#param_name as ::zerogc::GcRebrand<'new_gc, #collector_id>>::Branded);
660718
rewritten_restrictions.push(WherePredicate::Type(PredicateType {
661719
lifetimes: None,
662720
bounded_ty: rewritten_type.clone(),
@@ -683,29 +741,52 @@ fn impl_rebrand(target: &DeriveInput, info: &GcTypeInfo) -> Result<TokenStream,
683741
}
684742
rewritten_params.push(rewritten_param);
685743
}
744+
let mut field_types = Vec::new();
745+
match target.data {
746+
Data::Struct(ref s) => {
747+
for f in &s.fields {
748+
field_types.push(f.ty.clone());
749+
}
750+
},
751+
Data::Enum(ref e) => {
752+
for variant in &e.variants {
753+
for f in &variant.fields {
754+
field_types.push(f.ty.clone());
755+
}
756+
}
757+
},
758+
Data::Union(_) => {
759+
return Err(Error::new(target.ident.span(), "Unable to derive(GcErase) for unions"))
760+
}
761+
}
686762
let mut impl_generics = generics.clone();
687763
impl_generics.params.push(GenericParam::Lifetime(parse_quote!('new_gc)));
688-
impl_generics.params.push(GenericParam::Type(parse_quote!(S: ::zerogc::CollectorId)));
764+
if info.config.collector_id.is_none() {
765+
impl_generics.params.push(GenericParam::Type(parse_quote!(Id: ::zerogc::CollectorId)));
766+
}
767+
let assert_rebrand = field_types.iter().map(|field_type| {
768+
let span = field_type.span();
769+
quote_spanned!(span => <#field_type as ::zerogc::GcRebrand<'new_gc, #collector_id>>::assert_rebrand();)
770+
}).collect::<Vec<_>>();
689771
impl_generics.make_where_clause().predicates.extend(rewritten_restrictions);
690772
let (_, ty_generics, _) = generics.split_for_impl();
691773
let (impl_generics, _, where_clause) = impl_generics.split_for_impl();
692-
::proc_macro::Diagnostic::spanned(
693-
::proc_macro::Span::call_site(),
694-
::proc_macro::Level::Warning,
695-
"derive(GcRebrand) doesn't currently verify the correctness of its fields"
696-
).emit();
697774
Ok(quote! {
698-
unsafe impl #impl_generics ::zerogc::GcRebrand<'new_gc, S>
775+
unsafe impl #impl_generics ::zerogc::GcRebrand<'new_gc, #collector_id>
699776
for #name #ty_generics #where_clause {
700777
type Branded = #name::<#(#rewritten_params),*>;
778+
779+
fn assert_rebrand() {
780+
#(#assert_rebrand)*
781+
}
701782
}
702783
})
703784
}
704785
fn impl_trace(target: &DeriveInput, info: &GcTypeInfo) -> Result<TokenStream, Error> {
705786
let name = &target.ident;
706787
let generics = add_trait_bounds_except(
707788
&target.generics, parse_quote!(zerogc::Trace),
708-
&info.config.ignore_params
789+
&info.config.ignore_params, None
709790
)?;
710791
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
711792
let field_types: Vec<&Type>;
@@ -787,9 +868,17 @@ fn impl_trace(target: &DeriveInput, info: &GcTypeInfo) -> Result<TokenStream, Er
787868
}
788869
fn impl_gc_safe(target: &DeriveInput, info: &GcTypeInfo) -> Result<TokenStream, Error> {
789870
let name = &target.ident;
871+
let collector_id = &info.config.collector_id;
790872
let generics = add_trait_bounds_except(
791873
&target.generics, parse_quote!(zerogc::GcSafe),
792-
&info.config.ignore_params
874+
&info.config.ignore_params,
875+
Some(&mut |other: &Ident| {
876+
if let Some(ref collector_id) = *collector_id {
877+
other == collector_id // -> ignore collector_id for GcSafe
878+
} else {
879+
false
880+
}
881+
})
793882
)?;
794883
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
795884
let field_types: Vec<&Type> = match target.data {
@@ -859,7 +948,7 @@ fn impl_nop_trace(target: &DeriveInput, info: &GcTypeInfo) -> Result<TokenStream
859948
let name = &target.ident;
860949
let generics = add_trait_bounds_except(
861950
&target.generics, parse_quote!(zerogc::Trace),
862-
&info.config.ignore_params
951+
&info.config.ignore_params, None
863952
)?;
864953
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
865954
let field_types: Vec<&Type>;
@@ -914,12 +1003,18 @@ fn impl_nop_trace(target: &DeriveInput, info: &GcTypeInfo) -> Result<TokenStream
9141003

9151004
fn add_trait_bounds_except(
9161005
generics: &Generics, bound: TypeParamBound,
917-
ignored_params: &HashSet<Ident>
1006+
ignored_params: &HashSet<Ident>,
1007+
mut extra_ignore: Option<&mut dyn FnMut(&Ident) -> bool>
9181008
) -> Result<Generics, Error> {
9191009
let mut actually_ignored_args = HashSet::<Ident>::new();
9201010
let generics = add_trait_bounds(
9211011
&generics, bound,
9221012
&mut |param: &TypeParam| {
1013+
if let Some(ref mut extra) = extra_ignore {
1014+
if extra(&param.ident) {
1015+
return true; // ignore (but don't add to set)
1016+
}
1017+
}
9231018
if ignored_params.contains(&param.ident) {
9241019
actually_ignored_args.insert(param.ident.clone());
9251020
true

libs/derive/tests/basic.rs

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,24 @@
1-
use zerogc::{Gc, CollectorId, Trace, GcSafe, NullTrace, dummy_impl};
1+
use zerogc::{Gc, CollectorId, Trace, GcSafe, NullTrace, dummy_impl::{self, DummyCollectorId}};
22

33
use zerogc_derive::Trace;
44

55
#[derive(Trace)]
6-
#[zerogc(ignore_params(Id))]
6+
#[zerogc(collector_id(DummyCollectorId))]
7+
pub struct SpecificCollector<'gc> {
8+
gc: Gc<'gc, i32, DummyCollectorId>,
9+
rec: Gc<'gc, SpecificCollector<'gc>, DummyCollectorId>
10+
}
11+
12+
#[derive(Trace)]
13+
#[zerogc(collector_id(Id))]
714
pub struct Basic<'gc, Id: CollectorId> {
815
parent: Option<Gc<'gc, Basic<'gc, Id>, Id>>,
916
children: Vec<Gc<'gc, Basic<'gc, Id>, Id>>,
1017
value: String
1118
}
1219

1320
#[derive(Copy, Clone, Trace)]
14-
#[zerogc(copy, ignore_params(Id))]
21+
#[zerogc(copy, collector_id(Id))]
1522
pub struct BasicCopy<'gc, Id: CollectorId> {
1623
test: i32,
1724
value: i32,
@@ -20,7 +27,7 @@ pub struct BasicCopy<'gc, Id: CollectorId> {
2027

2128

2229
#[derive(Copy, Clone, Trace)]
23-
#[zerogc(copy, ignore_params(Id))]
30+
#[zerogc(copy, collector_id(Id))]
2431
pub enum BasicEnum<'gc, Id: CollectorId> {
2532
Unit,
2633
Tuple(i32),

src/lib.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -801,6 +801,12 @@ pub unsafe trait GcRebrand<'new_gc, Id: CollectorId>: Trace {
801801
/// This must have the same in-memory repr as `Self`,
802802
/// so that it's safe to transmute.
803803
type Branded: Trace + 'new_gc;
804+
805+
/// Assert this type can be rebranded
806+
///
807+
/// Only used by procedural derive
808+
#[doc(hidden)]
809+
fn assert_rebrand() {}
804810
}
805811
/// Indicates that it's safe to erase all GC lifetimes
806812
/// and change them to 'static (logically an 'unsafe)
@@ -815,6 +821,12 @@ pub unsafe trait GcErase<'a, Id: CollectorId>: Trace {
815821
/// This must have the same in-memory repr as `Self`,
816822
/// so that it's safe to transmute.
817823
type Erased: 'a;
824+
825+
/// Assert this type can be erased
826+
///
827+
/// Only used by procedural derive
828+
#[doc(hidden)]
829+
fn assert_erase() {}
818830
}
819831

820832
/// Indicates that a type can be traced by a garbage collector.

0 commit comments

Comments
 (0)