Skip to content

Commit f2b7d0d

Browse files
committed
Skip trivial fields in derived traversables
1 parent 6c218fc commit f2b7d0d

File tree

6 files changed

+360
-111
lines changed

6 files changed

+360
-111
lines changed

compiler/rustc_macros/Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,6 @@ quote = "1"
1313
syn = { version = "2.0.9", features = ["full"] }
1414
synstructure = "0.13.0"
1515
# tidy-alphabetical-end
16+
17+
[dev-dependencies]
18+
syn = { version = "*", features = ["visit-mut"] }

compiler/rustc_macros/src/lib.rs

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -74,17 +74,15 @@ decl_derive!([TyEncodable] => serialize::type_encodable_derive);
7474
decl_derive!([MetadataDecodable] => serialize::meta_decodable_derive);
7575
decl_derive!([MetadataEncodable] => serialize::meta_encodable_derive);
7676
decl_derive!(
77-
[TypeFoldable, attributes(type_foldable)] =>
77+
[TypeFoldable] =>
7878
/// Derives `TypeFoldable` for the annotated `struct` or `enum` (`union` is not supported).
7979
///
80-
/// Folds will produce a value of the same struct or enum variant as the input, with each field
81-
/// respectively folded (in definition order) using the `TypeFoldable` implementation for its
82-
/// type. However, if a field of a struct or of an enum variant is annotated with
83-
/// `#[type_foldable(identity)]` then that field will retain its incumbent value (and its type
84-
/// is not required to implement `TypeFoldable`). However use of this attribute is dangerous
85-
/// and should be used with extreme caution: should the type of the annotated field contain
86-
/// (now or in the future) a type that is of interest to a folder, it will not get folded (which
87-
/// may result in unexpected, hard-to-track bugs that could result in unsoundness).
80+
/// Folds will produce a value of the same struct or enum variant as the input, with trivial
81+
/// fields unchanged and all non-trivial fields respectively folded (in definition order) using
82+
/// the `TypeFoldable` implementation for its type. A field of type `T` is "trivial" if `T`
83+
/// both does not reference any generic type parameters and either
84+
/// - does not reference any `'tcx` lifetime parameter; or
85+
/// - does not contain anything that could be of interest to folders.
8886
///
8987
/// If the annotated type has a `'tcx` lifetime parameter, then that will be used as the
9088
/// lifetime for the type context/interner; otherwise the lifetime of the type context/interner
@@ -100,17 +98,15 @@ decl_derive!(
10098
traversable::traversable_derive::<traversable::Foldable>
10199
);
102100
decl_derive!(
103-
[TypeVisitable, attributes(type_visitable)] =>
101+
[TypeVisitable] =>
104102
/// Derives `TypeVisitable` for the annotated `struct` or `enum` (`union` is not supported).
105103
///
106-
/// Each field of the struct or enum variant will be visited (in definition order) using the
107-
/// `TypeVisitable` implementation for its type. However, if a field of a struct or of an enum
108-
/// variant is annotated with `#[type_visitable(ignore)]` then that field will not be visited
109-
/// (and its type is not required to implement `TypeVisitable`). However use of this attribute
110-
/// is dangerous and should be used with extreme caution: should the type of the annotated
111-
/// field (now or in the future) a type that is of interest to a visitor, it will not get
112-
/// visited (which may result in unexpected, hard-to-track bugs that could result in
113-
/// unsoundness).
104+
/// Each non-trivial field of the struct or enum variant will be visited (in definition order)
105+
/// using the `TypeVisitable` implementation for its type; trivial fields will be ignored. A
106+
/// field of type `T` is "trivial" if `T` both does not reference any generic type parameters
107+
/// and either
108+
/// - does not reference any `'tcx` lifetime parameter; or
109+
/// - does not contain anything that could be of interest to visitors.
114110
///
115111
/// If the annotated type has a `'tcx` lifetime parameter, then that will be used as the
116112
/// lifetime for the type context/interner; otherwise the lifetime of the type context/interner
Lines changed: 209 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -1,126 +1,247 @@
1-
use proc_macro2::TokenStream;
1+
use proc_macro2::{Ident, Span, TokenStream};
22
use quote::{quote, ToTokens};
3-
use syn::parse_quote;
3+
use std::collections::HashSet;
4+
use syn::{parse_quote, visit, Field, Generics, Lifetime};
5+
6+
#[cfg(test)]
7+
mod tests;
8+
9+
/// Generate a type parameter with the given `suffix` that does not conflict with
10+
/// any of the `existing` generics.
11+
fn gen_param(suffix: impl ToString, existing: &Generics) -> Ident {
12+
let mut suffix = suffix.to_string();
13+
while existing.type_params().any(|t| t.ident == suffix) {
14+
suffix.insert(0, '_');
15+
}
16+
Ident::new(&suffix, Span::call_site())
17+
}
18+
19+
#[derive(Clone, Copy, PartialEq)]
20+
enum Type {
21+
/// Describes a type that is not parameterised by the interner, and therefore cannot
22+
/// be of any interest to traversers.
23+
Trivial,
24+
25+
/// Describes a type that is parameterised by the interner lifetime `'tcx` but that is
26+
/// otherwise not generic.
27+
NotGeneric,
28+
29+
/// Describes a type that is generic.
30+
Generic,
31+
}
32+
use Type::*;
33+
34+
pub struct Interner<'a>(Option<&'a Lifetime>);
35+
36+
impl<'a> Interner<'a> {
37+
/// Return the `TyCtxt` interner for the given `structure`.
38+
///
39+
/// If the input represented by `structure` has a `'tcx` lifetime parameter, then that will be used
40+
/// used as the lifetime of the `TyCtxt`. Otherwise a `'tcx` lifetime parameter that is unrelated
41+
/// to the input will be used.
42+
fn resolve(generics: &'a Generics) -> Self {
43+
Self(
44+
generics
45+
.lifetimes()
46+
.find_map(|def| (def.lifetime.ident == "tcx").then_some(&def.lifetime)),
47+
)
48+
}
49+
}
50+
51+
impl ToTokens for Interner<'_> {
52+
fn to_tokens(&self, tokens: &mut TokenStream) {
53+
let default = parse_quote! { 'tcx };
54+
let lt = self.0.unwrap_or(&default);
55+
tokens.extend(quote! { ::rustc_middle::ty::TyCtxt<#lt> });
56+
}
57+
}
458

559
pub struct Foldable;
660
pub struct Visitable;
761

862
/// An abstraction over traversable traits.
963
pub trait Traversable {
10-
/// The trait that this `Traversable` represents.
11-
fn traversable() -> TokenStream;
12-
13-
/// The `match` arms for a traversal of this type.
14-
fn arms(structure: &mut synstructure::Structure<'_>) -> TokenStream;
15-
16-
/// The body of an implementation given the match `arms`.
17-
fn impl_body(arms: impl ToTokens) -> TokenStream;
64+
/// The trait that this `Traversable` represents, parameterised by `interner`.
65+
fn traversable(interner: &Interner<'_>) -> TokenStream;
66+
67+
/// Any supertraits that this trait is required to implement.
68+
fn supertraits(interner: &Interner<'_>) -> TokenStream;
69+
70+
/// A (`noop`) traversal of this trait upon the `bind` expression.
71+
fn traverse(bind: TokenStream, noop: bool) -> TokenStream;
72+
73+
/// A `match` arm for `variant`, where `f` generates the tokens for each binding.
74+
fn arm(
75+
variant: &synstructure::VariantInfo<'_>,
76+
f: impl FnMut(&synstructure::BindingInfo<'_>) -> TokenStream,
77+
) -> TokenStream;
78+
79+
/// The body of an implementation given the `interner`, `traverser` and match expression `body`.
80+
fn impl_body(
81+
interner: Interner<'_>,
82+
traverser: impl ToTokens,
83+
body: impl ToTokens,
84+
) -> TokenStream;
1885
}
1986

2087
impl Traversable for Foldable {
21-
fn traversable() -> TokenStream {
22-
quote! { ::rustc_middle::ty::fold::TypeFoldable<::rustc_middle::ty::TyCtxt<'tcx>> }
23-
}
24-
fn arms(structure: &mut synstructure::Structure<'_>) -> TokenStream {
25-
structure.each_variant(|vi| {
26-
let bindings = vi.bindings();
27-
vi.construct(|_, index| {
28-
let bind = &bindings[index];
29-
30-
let mut fixed = false;
31-
32-
// retain value of fields with #[type_foldable(identity)]
33-
bind.ast().attrs.iter().for_each(|x| {
34-
if !x.path().is_ident("type_foldable") {
35-
return;
36-
}
37-
let _ = x.parse_nested_meta(|nested| {
38-
if nested.path.is_ident("identity") {
39-
fixed = true;
40-
}
41-
Ok(())
42-
});
43-
});
44-
45-
if fixed {
46-
bind.to_token_stream()
47-
} else {
48-
quote! {
49-
::rustc_middle::ty::fold::TypeFoldable::try_fold_with(#bind, __folder)?
50-
}
51-
}
52-
})
53-
})
88+
fn traversable(interner: &Interner<'_>) -> TokenStream {
89+
quote! { ::rustc_middle::ty::fold::TypeFoldable<#interner> }
90+
}
91+
fn supertraits(interner: &Interner<'_>) -> TokenStream {
92+
Visitable::traversable(interner)
5493
}
55-
fn impl_body(arms: impl ToTokens) -> TokenStream {
94+
fn traverse(bind: TokenStream, noop: bool) -> TokenStream {
95+
if noop {
96+
bind
97+
} else {
98+
quote! { ::rustc_middle::ty::fold::TypeFoldable::try_fold_with(#bind, folder)? }
99+
}
100+
}
101+
fn arm(
102+
variant: &synstructure::VariantInfo<'_>,
103+
mut f: impl FnMut(&synstructure::BindingInfo<'_>) -> TokenStream,
104+
) -> TokenStream {
105+
let bindings = variant.bindings();
106+
variant.construct(|_, index| f(&bindings[index]))
107+
}
108+
fn impl_body(
109+
interner: Interner<'_>,
110+
traverser: impl ToTokens,
111+
body: impl ToTokens,
112+
) -> TokenStream {
56113
quote! {
57-
fn try_fold_with<__F: ::rustc_middle::ty::fold::FallibleTypeFolder<::rustc_middle::ty::TyCtxt<'tcx>>>(
114+
fn try_fold_with<#traverser: ::rustc_middle::ty::fold::FallibleTypeFolder<#interner>>(
58115
self,
59-
__folder: &mut __F
60-
) -> ::core::result::Result<Self, __F::Error> {
61-
::core::result::Result::Ok(match self { #arms })
116+
folder: &mut #traverser
117+
) -> ::core::result::Result<Self, #traverser::Error> {
118+
::core::result::Result::Ok(#body)
62119
}
63120
}
64121
}
65122
}
66123

67124
impl Traversable for Visitable {
68-
fn traversable() -> TokenStream {
69-
quote! { ::rustc_middle::ty::visit::TypeVisitable<::rustc_middle::ty::TyCtxt<'tcx>> }
125+
fn traversable(interner: &Interner<'_>) -> TokenStream {
126+
quote! { ::rustc_middle::ty::visit::TypeVisitable<#interner> }
70127
}
71-
fn arms(structure: &mut synstructure::Structure<'_>) -> TokenStream {
72-
// ignore fields with #[type_visitable(ignore)]
73-
structure.filter(|bi| {
74-
let mut ignored = false;
75-
76-
bi.ast().attrs.iter().for_each(|attr| {
77-
if !attr.path().is_ident("type_visitable") {
78-
return;
79-
}
80-
let _ = attr.parse_nested_meta(|nested| {
81-
if nested.path.is_ident("ignore") {
82-
ignored = true;
83-
}
84-
Ok(())
85-
});
86-
});
87-
88-
!ignored
89-
});
90-
91-
structure.each(|bind| {
92-
quote! {
93-
::rustc_middle::ty::visit::TypeVisitable::visit_with(#bind, __visitor)?;
94-
}
95-
})
128+
fn supertraits(_: &Interner<'_>) -> TokenStream {
129+
quote! { ::core::clone::Clone + ::core::fmt::Debug }
130+
}
131+
fn traverse(bind: TokenStream, noop: bool) -> TokenStream {
132+
if noop {
133+
quote! {}
134+
} else {
135+
quote! { ::rustc_middle::ty::visit::TypeVisitable::visit_with(#bind, visitor)?; }
136+
}
96137
}
97-
fn impl_body(arms: impl ToTokens) -> TokenStream {
138+
fn arm(
139+
variant: &synstructure::VariantInfo<'_>,
140+
f: impl FnMut(&synstructure::BindingInfo<'_>) -> TokenStream,
141+
) -> TokenStream {
142+
variant.bindings().iter().map(f).collect()
143+
}
144+
fn impl_body(
145+
interner: Interner<'_>,
146+
traverser: impl ToTokens,
147+
body: impl ToTokens,
148+
) -> TokenStream {
98149
quote! {
99-
fn visit_with<__V: ::rustc_middle::ty::visit::TypeVisitor<::rustc_middle::ty::TyCtxt<'tcx>>>(
150+
fn visit_with<#traverser: ::rustc_middle::ty::visit::TypeVisitor<#interner>>(
100151
&self,
101-
__visitor: &mut __V
102-
) -> ::std::ops::ControlFlow<__V::BreakTy> {
103-
match self { #arms }
104-
::std::ops::ControlFlow::Continue(())
152+
visitor: &mut #traverser
153+
) -> ::core::ops::ControlFlow<#traverser::BreakTy> {
154+
#body
155+
::core::ops::ControlFlow::Continue(())
105156
}
106157
}
107158
}
108159
}
109160

161+
impl Interner<'_> {
162+
/// We consider a type to be internable if it references either a generic type parameter or,
163+
/// if the interner is `TyCtxt<'tcx>`, the `'tcx` lifetime.
164+
fn type_of<'a>(
165+
&self,
166+
referenced_ty_params: &[&Ident],
167+
fields: impl IntoIterator<Item = &'a Field>,
168+
) -> Type {
169+
use visit::Visit;
170+
171+
struct Info<'a> {
172+
interner: &'a Lifetime,
173+
contains_interner: bool,
174+
}
175+
176+
impl Visit<'_> for Info<'_> {
177+
fn visit_lifetime(&mut self, i: &Lifetime) {
178+
if i == self.interner {
179+
self.contains_interner = true;
180+
} else {
181+
visit::visit_lifetime(self, i)
182+
}
183+
}
184+
}
185+
186+
if !referenced_ty_params.is_empty() {
187+
Generic
188+
} else if let Some(interner) = &self.0 && fields.into_iter().any(|field| {
189+
let mut info = Info { interner, contains_interner: false };
190+
info.visit_type(&field.ty);
191+
info.contains_interner
192+
}) {
193+
NotGeneric
194+
} else {
195+
Trivial
196+
}
197+
}
198+
}
199+
110200
pub fn traversable_derive<T: Traversable>(
111201
mut structure: synstructure::Structure<'_>,
112202
) -> TokenStream {
113-
if let syn::Data::Union(_) = structure.ast().data {
114-
panic!("cannot derive on union")
115-
}
203+
let ast = structure.ast();
204+
205+
let interner = Interner::resolve(&ast.generics);
206+
let traverser = gen_param("T", &ast.generics);
207+
let traversable = T::traversable(&interner);
116208

117-
structure.add_bounds(synstructure::AddBounds::Generics);
209+
structure.underscore_const(true);
210+
structure.add_bounds(synstructure::AddBounds::None);
118211
structure.bind_with(|_| synstructure::BindStyle::Move);
119212

120-
if !structure.ast().generics.lifetimes().any(|lt| lt.lifetime.ident == "tcx") {
213+
if interner.0.is_none() {
121214
structure.add_impl_generic(parse_quote! { 'tcx });
122215
}
123216

124-
let arms = T::arms(&mut structure);
125-
structure.bound_impl(T::traversable(), T::impl_body(arms))
217+
// If our derived implementation will be generic over the traversable type, then we must
218+
// constrain it to only those generic combinations that satisfy the traversable trait's
219+
// supertraits.
220+
if ast.generics.type_params().next().is_some() {
221+
let supertraits = T::supertraits(&interner);
222+
structure.add_where_predicate(parse_quote! { Self: #supertraits });
223+
}
224+
225+
// We add predicates to each generic field type, rather than to our generic type parameters.
226+
// This results in a "perfect derive", but it can result in trait solver cycles if any type
227+
// parameters are involved in recursive type definitions; fortunately that is not the case (yet).
228+
let mut predicates = HashSet::new();
229+
let arms = structure.each_variant(|variant| {
230+
T::arm(variant, |bind| {
231+
let ast = bind.ast();
232+
let field_ty = interner.type_of(&bind.referenced_ty_params(), [ast]);
233+
if field_ty == Generic {
234+
predicates.insert(ast.ty.clone());
235+
}
236+
T::traverse(bind.into_token_stream(), field_ty == Trivial)
237+
})
238+
});
239+
// the order in which `where` predicates appear in rust source is irrelevant
240+
#[allow(rustc::potential_query_instability)]
241+
for ty in predicates {
242+
structure.add_where_predicate(parse_quote! { #ty: #traversable });
243+
}
244+
let body = quote! { match self { #arms } };
245+
246+
structure.bound_impl(traversable, T::impl_body(interner, traverser, body))
126247
}

0 commit comments

Comments
 (0)