Skip to content

Commit 14a1696

Browse files
committed
Initial support for generics
Existing generic types and predicates are passed through unmodified.
1 parent 0b20578 commit 14a1696

File tree

2 files changed

+201
-17
lines changed

2 files changed

+201
-17
lines changed

diffus-derive-test/src/lib.rs

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,4 +412,61 @@ mod test {
412412
&vec![string::Edit::Copy('a'), string::Edit::Insert('\''),]
413413
);
414414
}
415+
416+
mod generics {
417+
use diffus::Diffus;
418+
419+
pub trait Thing {
420+
type Foo;
421+
type Bar;
422+
}
423+
424+
pub struct ConcreteThing;
425+
426+
impl Thing for ConcreteThing {
427+
type Foo = String;
428+
type Bar = i64;
429+
}
430+
431+
#[derive(Diffus)]
432+
pub struct TestNamedStruct<A> where A: Thing {
433+
pub a: A::Foo,
434+
pub inner: i32,
435+
}
436+
437+
#[derive(Diffus)]
438+
pub enum TestTuple<A> where A: Thing {
439+
Hello {
440+
bar: A::Bar,
441+
},
442+
UnitVariant,
443+
TupleVariant(A::Bar, A::Bar),
444+
}
445+
446+
#[derive(Diffus)]
447+
pub struct TestUnnamedStruct<A>(pub A::Foo) where A: Thing;
448+
}
449+
450+
#[test]
451+
fn test() {
452+
use self::generics::{ConcreteThing, TestNamedStruct};
453+
use edit::string;
454+
455+
let a: TestNamedStruct<ConcreteThing> = TestNamedStruct { a: "a".to_string(), inner: 12 };
456+
let ap = TestNamedStruct { a: "a'".to_string(), inner: 13 };
457+
458+
let diff = a.diff(&ap);
459+
let actual_a = diff.change().unwrap().a.change().unwrap();
460+
let actual_inner = diff.change().unwrap().inner.change().unwrap();
461+
462+
assert_eq!(
463+
actual_a,
464+
&vec![string::Edit::Copy('a'), string::Edit::Insert('\''),]
465+
);
466+
467+
assert_eq!(
468+
actual_inner,
469+
&(&12, &13),
470+
);
471+
}
415472
}

diffus-derive/src/lib.rs

Lines changed: 144 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -132,18 +132,33 @@ fn input_lifetime(generics: &syn::Generics) -> Option<&syn::Lifetime> {
132132
lifetime
133133
}
134134

135+
struct Generics {
136+
ty_generic_params: syn::punctuated::Punctuated<syn::GenericParam, syn::token::Comma>,
137+
138+
edited_ty_generic_params: syn::punctuated::Punctuated<syn::GenericParam, syn::token::Comma>,
139+
edited_ty_where_clause: syn::WhereClause,
140+
141+
impl_diffable_generic_params: syn::punctuated::Punctuated<syn::GenericParam, syn::token::Comma>,
142+
impl_diffable_where_clause: syn::WhereClause,
143+
144+
impl_lifetime: syn::Lifetime,
145+
}
146+
135147
#[proc_macro_derive(Diffus)]
136148
pub fn derive_diffus(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
137149
let input: syn::DeriveInput = syn::parse2(proc_macro2::TokenStream::from(input)).unwrap();
138150

139151
let ident = &input.ident;
140152
let vis = &input.vis;
141-
let where_clause = &input.generics.where_clause;
153+
142154
let edited_ident = syn::parse_str::<syn::Path>(&format!("Edited{}", ident)).unwrap();
143155

144-
let data_lifetime = input_lifetime(&input.generics);
145-
let default_lifetime = syn::parse_str::<syn::Lifetime>("'diffus_a").unwrap();
146-
let impl_lifetime = data_lifetime.unwrap_or(&default_lifetime);
156+
let Generics {
157+
ty_generic_params,
158+
edited_ty_generic_params, edited_ty_where_clause,
159+
impl_diffable_generic_params, impl_diffable_where_clause,
160+
impl_lifetime,
161+
} = Generics::new(&input.generics, &input.data);
147162

148163
#[cfg(feature = "serialize-impl")]
149164
let derive_serialize = Some(quote! { #[derive(serde::Serialize)] });
@@ -182,8 +197,8 @@ pub fn derive_diffus(input: proc_macro::TokenStream) -> proc_macro::TokenStream
182197
}
183198
});
184199

185-
let unit_enum_impl_lifetime = if has_non_unit_variant {
186-
Some(impl_lifetime.clone())
200+
let unit_enum_generic_params = if has_non_unit_variant {
201+
Some(edited_ty_generic_params.clone())
187202
} else {
188203
None
189204
};
@@ -262,12 +277,12 @@ pub fn derive_diffus(input: proc_macro::TokenStream) -> proc_macro::TokenStream
262277

263278
quote! {
264279
#derive_serialize
265-
#vis enum #edited_ident <#unit_enum_impl_lifetime> where #where_clause {
280+
#vis enum #edited_ident <#unit_enum_generic_params> #edited_ty_where_clause {
266281
#(#edit_variants),*
267282
}
268283

269-
impl<#impl_lifetime> diffus::Diffable<#impl_lifetime> for #ident <#data_lifetime> where #where_clause {
270-
type Diff = diffus::edit::enm::Edit<#impl_lifetime, Self, #edited_ident <#unit_enum_impl_lifetime>>;
284+
impl<#impl_diffable_generic_params> diffus::Diffable<#impl_lifetime> for #ident <#ty_generic_params> #impl_diffable_where_clause {
285+
type Diff = diffus::edit::enm::Edit<#impl_lifetime, Self, #edited_ident <#unit_enum_generic_params>>;
271286

272287
fn diff(&#impl_lifetime self, other: &#impl_lifetime Self) -> diffus::edit::Edit<#impl_lifetime, Self> {
273288
match (self, other) {
@@ -290,12 +305,12 @@ pub fn derive_diffus(input: proc_macro::TokenStream) -> proc_macro::TokenStream
290305
syn::Fields::Named(_) => {
291306
quote! {
292307
#derive_serialize
293-
#vis struct #edited_ident<#impl_lifetime> where #where_clause {
308+
#vis struct #edited_ident<#edited_ty_generic_params> #edited_ty_where_clause {
294309
#edit_fields
295310
}
296311

297-
impl<#impl_lifetime> diffus::Diffable<#impl_lifetime> for #ident <#data_lifetime> where #where_clause {
298-
type Diff = #edited_ident<#impl_lifetime>;
312+
impl<#impl_diffable_generic_params> diffus::Diffable<#impl_lifetime> for #ident <#ty_generic_params> #impl_diffable_where_clause {
313+
type Diff = #edited_ident<#edited_ty_generic_params>;
299314

300315
fn diff(&#impl_lifetime self, other: &#impl_lifetime Self) -> diffus::edit::Edit<#impl_lifetime, Self> {
301316
match ( #field_diffs ) {
@@ -311,10 +326,10 @@ pub fn derive_diffus(input: proc_macro::TokenStream) -> proc_macro::TokenStream
311326
syn::Fields::Unnamed(_) => {
312327
quote! {
313328
#derive_serialize
314-
#vis struct #edited_ident<#impl_lifetime> ( #edit_fields ) where #where_clause;
329+
#vis struct #edited_ident<#edited_ty_generic_params> ( #edit_fields ) #edited_ty_where_clause;
315330

316-
impl<#impl_lifetime> diffus::Diffable<#impl_lifetime> for #ident <#data_lifetime> where #where_clause {
317-
type Diff = #edited_ident<#impl_lifetime>;
331+
impl<#impl_diffable_generic_params> diffus::Diffable<#impl_lifetime> for #ident <#ty_generic_params> #impl_diffable_where_clause {
332+
type Diff = #edited_ident<#edited_ty_generic_params>;
318333

319334
fn diff(&#impl_lifetime self, other: &#impl_lifetime Self) -> diffus::edit::Edit<#impl_lifetime, Self> {
320335
match ( #field_diffs ) {
@@ -330,9 +345,9 @@ pub fn derive_diffus(input: proc_macro::TokenStream) -> proc_macro::TokenStream
330345
syn::Fields::Unit => {
331346
quote! {
332347
#derive_serialize
333-
#vis struct #edited_ident< > where #where_clause;
348+
#vis struct #edited_ident< > #edited_ty_where_clause;
334349

335-
impl<#impl_lifetime> diffus::Diffable<#impl_lifetime> for #ident< > where #where_clause {
350+
impl<#impl_lifetime> diffus::Diffable<#impl_lifetime> for #ident< > #impl_diffable_where_clause {
336351
type Diff = #edited_ident;
337352

338353
fn diff(&#impl_lifetime self, other: &#impl_lifetime Self) -> diffus::edit::Edit<#impl_lifetime, Self> {
@@ -346,3 +361,115 @@ pub fn derive_diffus(input: proc_macro::TokenStream) -> proc_macro::TokenStream
346361
syn::Data::Union(_) => panic!("union type not supported yet"),
347362
})
348363
}
364+
365+
impl Generics {
366+
pub fn new(
367+
input_generics: &syn::Generics,
368+
data: &syn::Data,
369+
) -> Self {
370+
let input_generic_params = &input_generics.params;
371+
let input_where_clause = &input_generics.where_clause;
372+
let empty_where_clause = input_generics.clone().make_where_clause().clone();
373+
374+
let generic_types_used = Self::collect_generic_types_used(input_generics, data);
375+
376+
let ty_generic_params = input_generic_params.clone();
377+
378+
let mut edited_ty_generic_params = ty_generic_params.clone();
379+
let mut edited_ty_where_clause = input_where_clause.clone().unwrap_or(empty_where_clause.clone());
380+
381+
let mut impl_diffable_generic_params = input_generic_params.clone();
382+
let mut impl_diffable_where_clause = input_where_clause.clone().unwrap_or(empty_where_clause);
383+
384+
let explicit_data_lifetime = input_lifetime(input_generics);
385+
let impl_lifetime = explicit_data_lifetime.cloned().unwrap_or_else(|| {
386+
let default_lifetime = syn::parse_str::<syn::Lifetime>("'diffus_a").unwrap();
387+
388+
// Add the lifetime into the generics lists.
389+
impl_diffable_generic_params.insert(0, syn::GenericParam::Lifetime(syn::LifetimeDef::new(default_lifetime.clone())));
390+
edited_ty_generic_params.insert(0, syn::GenericParam::Lifetime(syn::LifetimeDef::new(default_lifetime.clone())));
391+
392+
default_lifetime.clone()
393+
});
394+
395+
// Ensure that all generic types that exist live for as long as the diffus lifetime.
396+
impl_diffable_where_clause.predicates.extend(input_generics.type_params().map(|type_param| {
397+
let where_predicate = quote!(#type_param : #impl_lifetime);
398+
let where_predicate: syn::WherePredicate = syn::parse(where_predicate.into()).unwrap();
399+
where_predicate
400+
}));
401+
402+
// Ensure that all generic types actually used are diffable and live for as long as the
403+
// diffus lifetime.
404+
for generic_ty_path in generic_types_used {
405+
let where_predicates = vec![
406+
// quote!(<#generic_ty_path as diffus::Diffable<#impl_lifetime>>::Diff : diffus::Diffable<#impl_lifetime> + #impl_lifetime),
407+
quote!(#generic_ty_path : diffus::Diffable<#impl_lifetime> + #impl_lifetime),
408+
];
409+
let where_predicates: Vec<_> = where_predicates.into_iter().map(|wp| syn::parse::<syn::WherePredicate>(wp.into()).unwrap()).collect();
410+
411+
impl_diffable_where_clause.predicates.extend(where_predicates.clone());
412+
edited_ty_where_clause.predicates.extend(where_predicates.clone());
413+
}
414+
415+
Generics {
416+
ty_generic_params,
417+
edited_ty_generic_params, edited_ty_where_clause,
418+
impl_diffable_generic_params, impl_diffable_where_clause,
419+
impl_lifetime,
420+
}
421+
}
422+
423+
/// Collects all of the generic types used in a type including associated types.
424+
fn collect_generic_types_used(
425+
input_generics: &syn::Generics,
426+
data: &syn::Data,
427+
) -> Vec<syn::Path> {
428+
let all_possible_fields: Vec<&syn::Fields> = match *data {
429+
syn::Data::Struct(ref s) => vec![&s.fields],
430+
syn::Data::Enum(ref e) => e.variants.iter().map(|v| &v.fields).collect(),
431+
syn::Data::Union(..) => Vec::new(), // unimplemented
432+
};
433+
434+
let all_possible_types: Vec<&syn::Type> = all_possible_fields.into_iter().flat_map(|fields| match fields {
435+
syn::Fields::Named(ref fields) => fields.named.iter().map(|f| &f.ty).collect(),
436+
syn::Fields::Unnamed(ref fields) => fields.unnamed.iter().map(|f| &f.ty).collect(),
437+
syn::Fields::Unit => Vec::new(),
438+
}).collect();
439+
440+
let mut generic_types_used = Vec::new();
441+
let mut remaining_types_to_check = all_possible_types.clone();
442+
443+
while let Some(type_to_check) = remaining_types_to_check.pop() {
444+
match *type_to_check {
445+
syn::Type::Path(ref path) => {
446+
if let Some(first_segment) = path.path.segments.first().map(|s| &s.ident) {
447+
let first_segment: syn::Ident = first_segment.clone().into();
448+
449+
if input_generics.type_params().any(|type_param| type_param.ident == first_segment) {
450+
generic_types_used.push(path.path.clone());
451+
}
452+
}
453+
},
454+
455+
syn::Type::Array(ref array) => remaining_types_to_check.push(&array.elem),
456+
syn::Type::Group(ref group) => remaining_types_to_check.push(&group.elem),
457+
syn::Type::Paren(ref paren) => remaining_types_to_check.push(&paren.elem),
458+
syn::Type::Ptr(ref ptr) => remaining_types_to_check.push(&ptr.elem),
459+
syn::Type::Reference(ref reference) => remaining_types_to_check.push(&reference.elem),
460+
syn::Type::Slice(ref slice) => remaining_types_to_check.push(&slice.elem),
461+
syn::Type::Tuple(ref tuple) => remaining_types_to_check.extend(tuple.elems.iter()),
462+
syn::Type::Verbatim(..) |
463+
syn::Type::ImplTrait(..) |
464+
syn::Type::Infer(..) |
465+
syn::Type::Macro(..) |
466+
syn::Type::Never(..) |
467+
syn::Type::TraitObject(..) |
468+
syn::Type::BareFn(..) => (),
469+
_ => (), // unknown/unsupported type
470+
}
471+
}
472+
473+
generic_types_used
474+
}
475+
}

0 commit comments

Comments
 (0)