2
2
3
3
use proc_macro2:: TokenStream ;
4
4
use quote:: { format_ident, quote, ToTokens } ;
5
- use syn:: { parse_quote, spanned:: Spanned as _, token} ;
5
+ use syn:: {
6
+ parse_quote,
7
+ punctuated:: { self , Punctuated } ,
8
+ spanned:: Spanned as _,
9
+ } ;
6
10
7
11
use crate :: utils:: GenericsSearch ;
8
12
@@ -52,54 +56,69 @@ impl<'i> TryFrom<&'i syn::DeriveInput> for StructuralExpansion<'i> {
52
56
}
53
57
54
58
impl StructuralExpansion < ' _ > {
55
- /// Generates body of the [`PartialEq::eq()`] method implementation for this
56
- /// [`StructuralExpansion`].
57
- fn eq_body ( & self ) -> TokenStream {
58
- // Special case: empty enum.
59
+ /// Generates body of the [`PartialEq::eq()`]/[`PartialEq::ne()`] method implementation for this
60
+ /// [`StructuralExpansion`], if it's required .
61
+ fn body ( & self , eq : bool ) -> Option < TokenStream > {
62
+ // Special case: empty enum (also, no need for `ne()` method in this case) .
59
63
if self . variants . is_empty ( ) {
60
- return quote ! { match * self { } } ;
64
+ return eq . then ( || quote ! { match * self { } } ) ;
61
65
}
62
- // Special case: no fields to compare.
66
+
67
+ let no_fields_result = quote ! { #eq } ;
68
+
69
+ // Special case: no fields to compare (also, no need for `ne()` method in this case).
63
70
if self . variants . len ( ) == 1 && self . variants [ 0 ] . 1 . is_empty ( ) {
64
- return quote ! { true } ;
71
+ return eq . then_some ( no_fields_result ) ;
65
72
}
66
73
67
- let discriminants_eq = ( self . variants . len ( ) > 1 ) . then ( || {
74
+ let ( cmp, chain) = if eq {
75
+ ( quote ! { == } , quote ! { && } )
76
+ } else {
77
+ ( quote ! { != } , quote ! { || } )
78
+ } ;
79
+
80
+ let discriminants_cmp = ( self . variants . len ( ) > 1 ) . then ( || {
68
81
quote ! {
69
- derive_more:: core:: mem:: discriminant( self ) ==
82
+ derive_more:: core:: mem:: discriminant( self ) #cmp
70
83
derive_more:: core:: mem:: discriminant( __other)
71
84
}
72
85
} ) ;
73
86
74
- let matched_variants = self
87
+ let match_arms = self
75
88
. variants
76
89
. iter ( )
77
90
. filter_map ( |( variant, fields) | {
78
91
if fields. is_empty ( ) {
79
92
return None ;
80
93
}
94
+
81
95
let variant = variant. map ( |variant| quote ! { :: #variant } ) ;
82
96
let self_pattern = fields. arm_pattern ( "__self_" ) ;
83
97
let other_pattern = fields. arm_pattern ( "__other_" ) ;
84
- let val_eqs = ( 0 ..fields. len ( ) ) . map ( |num| {
85
- let self_val = format_ident ! ( "__self_{num}" ) ;
86
- let other_val = format_ident ! ( "__other_{num}" ) ;
87
- quote ! { #self_val == #other_val }
88
- } ) ;
98
+
99
+ let mut val_eqs = ( 0 ..fields. len ( ) )
100
+ . map ( |num| {
101
+ let self_val = format_ident ! ( "__self_{num}" ) ;
102
+ let other_val = format_ident ! ( "__other_{num}" ) ;
103
+ punctuated:: Pair :: Punctuated (
104
+ quote ! { #self_val #cmp #other_val } ,
105
+ & chain,
106
+ )
107
+ } )
108
+ . collect :: < Punctuated < TokenStream , _ > > ( ) ;
109
+ _ = val_eqs. pop_punct ( ) ;
110
+
89
111
Some ( quote ! {
90
- ( Self #variant #self_pattern, Self #variant #other_pattern) => {
91
- #( #val_eqs ) &&*
92
- }
112
+ ( Self #variant #self_pattern, Self #variant #other_pattern) => { #val_eqs }
93
113
} )
94
114
} )
95
115
. collect :: < Vec < _ > > ( ) ;
96
- let match_expr = ( !matched_variants. is_empty ( ) ) . then ( || {
97
- let always_true_arm =
98
- ( matched_variants. len ( ) != self . variants . len ( ) ) . then ( || {
99
- quote ! { _ => true }
100
- } ) ;
116
+ let match_expr = ( !match_arms. is_empty ( ) ) . then ( || {
117
+ let no_fields_arm = ( match_arms. len ( ) != self . variants . len ( ) ) . then ( || {
118
+ quote ! { _ => #no_fields_result }
119
+ } ) ;
101
120
let unreachable_arm = ( self . variants . len ( ) > 1
102
- && always_true_arm . is_none ( ) )
121
+ && no_fields_arm . is_none ( ) )
103
122
. then ( || {
104
123
quote ! {
105
124
// SAFETY: This arm is never reachable, but is required by the expanded
@@ -110,19 +129,25 @@ impl StructuralExpansion<'_> {
110
129
111
130
quote ! {
112
131
match ( self , __other) {
113
- #( #matched_variants , ) *
114
- #always_true_arm
132
+ #( #match_arms , ) *
133
+ #no_fields_arm
115
134
#unreachable_arm
116
135
}
117
136
}
118
137
} ) ;
119
138
120
- let and = ( discriminants_eq. is_some ( ) && match_expr. is_some ( ) )
121
- . then_some ( token:: AndAnd :: default ( ) ) ;
122
-
123
- quote ! {
124
- #discriminants_eq #and #match_expr
139
+ // If there is only `mem::discriminant()` comparison, there is no need to generate `ne()`
140
+ // method in the expansion, as its default implementation will do just fine.
141
+ if !eq && discriminants_cmp. is_some ( ) && match_expr. is_none ( ) {
142
+ return None ;
125
143
}
144
+
145
+ let chain =
146
+ ( discriminants_cmp. is_some ( ) && match_expr. is_some ( ) ) . then_some ( chain) ;
147
+
148
+ Some ( quote ! {
149
+ #discriminants_cmp #chain #match_expr
150
+ } )
126
151
}
127
152
}
128
153
@@ -143,17 +168,26 @@ impl ToTokens for StructuralExpansion<'_> {
143
168
}
144
169
let ( impl_generics, ty_generics, where_clause) = generics. split_for_impl ( ) ;
145
170
146
- let eq_body = self . eq_body ( ) ;
171
+ let eq_method = self . body ( true ) . map ( |body| {
172
+ quote ! {
173
+ #[ inline]
174
+ fn eq( & self , __other: & Self ) -> bool { #body }
175
+ }
176
+ } ) ;
177
+ let ne_method = self . body ( false ) . map ( |body| {
178
+ quote ! {
179
+ #[ inline]
180
+ fn ne( & self , __other: & Self ) -> bool { #body }
181
+ }
182
+ } ) ;
147
183
148
184
quote ! {
149
185
#[ automatically_derived]
150
186
impl #impl_generics derive_more:: core:: cmp:: PartialEq for #ty #ty_generics
151
187
#where_clause
152
188
{
153
- #[ inline]
154
- fn eq( & self , __other: & Self ) -> bool {
155
- #eq_body
156
- }
189
+ #eq_method
190
+ #ne_method
157
191
}
158
192
}
159
193
. to_tokens ( tokens) ;
0 commit comments