Skip to content

Commit b982f90

Browse files
authored
Generate PartialEq::ne() (#475, #163)
Part of #163 ## Synopsis `derive(PartialEq)` in `std` generates only `PartialEq::eq()` method implementation in its expansion. This silently kills any potential performance benefits if the underlying types implement `PartialEq::ne()` more efficiently than the default implementation (`!PartialEq::eq()`). ## Solution Generate `PartialEq::ne()` implementation always as well, by structurally calling `!=` operator instead of `==`. ## Additionally Unfortunately, the `assert_ne!()` macro also doesn't call `PartialEq::ne()` or `!=` operator inside its expansion, and only negates the equality check. Found this when running tests against a purposely incorrect `PartialEq::ne()` implementation. That's why the `assert_ne!()` macro is redefined in `partial_eq` tests as `asset!(left != right)`.
1 parent f632ac6 commit b982f90

File tree

4 files changed

+111
-41
lines changed

4 files changed

+111
-41
lines changed

CHANGELOG.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,10 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
1919
([#459](https://github.com/JelteF/derive_more/pull/459))
2020
- Support structs with no fields in `FromStr` derive.
2121
([#469](https://github.com/JelteF/derive_more/pull/469))
22-
- Add `PartialEq` derive similar to `std`'s one, but considering generics correctly.
23-
([#473](https://github.com/JelteF/derive_more/pull/473))
22+
- Add `PartialEq` derive similar to `std`'s one, but considering generics correctly,
23+
and implementing `ne()` method as well.
24+
([#473](https://github.com/JelteF/derive_more/pull/473),
25+
[#475](https://github.com/JelteF/derive_more/pull/475))
2426
- Proxy-pass `#[allow]`/`#[expect]` attributes of the type in `Constructor` derive.
2527
([#477](https://github.com/JelteF/derive_more/pull/477))
2628

impl/doc/eq.md

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@ their type structure.
99
## Structural equality
1010

1111
Deriving `PartialEq` for enums/structs works in a similar way to the one in `std`,
12-
by comparing all the available fields, but, in the contrast, does not overconstrain
13-
generic parameters.
12+
by comparing all the available fields, but, in the contrast:
13+
1. Does not overconstrain generic parameters.
14+
2. Implements `PartialEq::ne()` method as well, to propagate possible efficient
15+
implementations of this method from the underlying types.
1416

1517

1618
### Structs
@@ -72,6 +74,15 @@ where
7274
}
7375
}
7476
}
77+
78+
#[inline]
79+
fn ne(&self, other: &Self) -> bool {
80+
match (self, other) {
81+
(Self { a: self_0, b: self_1, c: self_2 }, Self { a: other_0, b: other_1, c: other_2 }) => {
82+
self_0 != other_0 || self_1 != other_1 || self_2 != other_2
83+
}
84+
}
85+
}
7586
}
7687
```
7788

@@ -143,5 +154,16 @@ where
143154
_ => unsafe { std::hint::unreachable_unchecked() },
144155
}
145156
}
157+
158+
#[inline]
159+
fn ne(&self, other: &Self) -> bool {
160+
std::mem::discriminant(self) != std::mem::discriminant(other) ||
161+
match (self, other) {
162+
(Self::A(self_0), Self::A(other_0)) => { self_0 != other_0 }
163+
(Self::B { b: self_0 }, Self::B { b: other_0 }) => { self_0 != other_0 }
164+
(Self::C(self_0), Self::C(other_0)) => { self_0 != other_0 }
165+
_ => unsafe { std::hint::unreachable_unchecked() },
166+
}
167+
}
146168
}
147169
```

impl/src/partial_eq.rs

Lines changed: 71 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,11 @@
22
33
use proc_macro2::TokenStream;
44
use quote::{format_ident, quote, ToTokens};
5-
use syn::{parse_quote, spanned::Spanned as _, token};
5+
use syn::{
6+
parse_quote,
7+
punctuated::{self, Punctuated},
8+
spanned::Spanned as _,
9+
};
610

711
use crate::utils::GenericsSearch;
812

@@ -52,54 +56,69 @@ impl<'i> TryFrom<&'i syn::DeriveInput> for StructuralExpansion<'i> {
5256
}
5357

5458
impl StructuralExpansion<'_> {
55-
/// Generates body of the [`PartialEq::eq()`] method implementation for this
56-
/// [`StructuralExpansion`].
57-
fn eq_body(&self) -> TokenStream {
58-
// Special case: empty enum.
59+
/// Generates body of the [`PartialEq::eq()`]/[`PartialEq::ne()`] method implementation for this
60+
/// [`StructuralExpansion`], if it's required.
61+
fn body(&self, eq: bool) -> Option<TokenStream> {
62+
// Special case: empty enum (also, no need for `ne()` method in this case).
5963
if self.variants.is_empty() {
60-
return quote! { match *self {} };
64+
return eq.then(|| quote! { match *self {} });
6165
}
62-
// Special case: no fields to compare.
66+
67+
let no_fields_result = quote! { #eq };
68+
69+
// Special case: no fields to compare (also, no need for `ne()` method in this case).
6370
if self.variants.len() == 1 && self.variants[0].1.is_empty() {
64-
return quote! { true };
71+
return eq.then_some(no_fields_result);
6572
}
6673

67-
let discriminants_eq = (self.variants.len() > 1).then(|| {
74+
let (cmp, chain) = if eq {
75+
(quote! { == }, quote! { && })
76+
} else {
77+
(quote! { != }, quote! { || })
78+
};
79+
80+
let discriminants_cmp = (self.variants.len() > 1).then(|| {
6881
quote! {
69-
derive_more::core::mem::discriminant(self) ==
82+
derive_more::core::mem::discriminant(self) #cmp
7083
derive_more::core::mem::discriminant(__other)
7184
}
7285
});
7386

74-
let matched_variants = self
87+
let match_arms = self
7588
.variants
7689
.iter()
7790
.filter_map(|(variant, fields)| {
7891
if fields.is_empty() {
7992
return None;
8093
}
94+
8195
let variant = variant.map(|variant| quote! { :: #variant });
8296
let self_pattern = fields.arm_pattern("__self_");
8397
let other_pattern = fields.arm_pattern("__other_");
84-
let val_eqs = (0..fields.len()).map(|num| {
85-
let self_val = format_ident!("__self_{num}");
86-
let other_val = format_ident!("__other_{num}");
87-
quote! { #self_val == #other_val }
88-
});
98+
99+
let mut val_eqs = (0..fields.len())
100+
.map(|num| {
101+
let self_val = format_ident!("__self_{num}");
102+
let other_val = format_ident!("__other_{num}");
103+
punctuated::Pair::Punctuated(
104+
quote! { #self_val #cmp #other_val },
105+
&chain,
106+
)
107+
})
108+
.collect::<Punctuated<TokenStream, _>>();
109+
_ = val_eqs.pop_punct();
110+
89111
Some(quote! {
90-
(Self #variant #self_pattern, Self #variant #other_pattern) => {
91-
#( #val_eqs )&&*
92-
}
112+
(Self #variant #self_pattern, Self #variant #other_pattern) => { #val_eqs }
93113
})
94114
})
95115
.collect::<Vec<_>>();
96-
let match_expr = (!matched_variants.is_empty()).then(|| {
97-
let always_true_arm =
98-
(matched_variants.len() != self.variants.len()).then(|| {
99-
quote! { _ => true }
100-
});
116+
let match_expr = (!match_arms.is_empty()).then(|| {
117+
let no_fields_arm = (match_arms.len() != self.variants.len()).then(|| {
118+
quote! { _ => #no_fields_result }
119+
});
101120
let unreachable_arm = (self.variants.len() > 1
102-
&& always_true_arm.is_none())
121+
&& no_fields_arm.is_none())
103122
.then(|| {
104123
quote! {
105124
// SAFETY: This arm is never reachable, but is required by the expanded
@@ -110,19 +129,25 @@ impl StructuralExpansion<'_> {
110129

111130
quote! {
112131
match (self, __other) {
113-
#( #matched_variants , )*
114-
#always_true_arm
132+
#( #match_arms , )*
133+
#no_fields_arm
115134
#unreachable_arm
116135
}
117136
}
118137
});
119138

120-
let and = (discriminants_eq.is_some() && match_expr.is_some())
121-
.then_some(token::AndAnd::default());
122-
123-
quote! {
124-
#discriminants_eq #and #match_expr
139+
// If there is only `mem::discriminant()` comparison, there is no need to generate `ne()`
140+
// method in the expansion, as its default implementation will do just fine.
141+
if !eq && discriminants_cmp.is_some() && match_expr.is_none() {
142+
return None;
125143
}
144+
145+
let chain =
146+
(discriminants_cmp.is_some() && match_expr.is_some()).then_some(chain);
147+
148+
Some(quote! {
149+
#discriminants_cmp #chain #match_expr
150+
})
126151
}
127152
}
128153

@@ -143,17 +168,26 @@ impl ToTokens for StructuralExpansion<'_> {
143168
}
144169
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
145170

146-
let eq_body = self.eq_body();
171+
let eq_method = self.body(true).map(|body| {
172+
quote! {
173+
#[inline]
174+
fn eq(&self, __other: &Self) -> bool { #body }
175+
}
176+
});
177+
let ne_method = self.body(false).map(|body| {
178+
quote! {
179+
#[inline]
180+
fn ne(&self, __other: &Self) -> bool { #body }
181+
}
182+
});
147183

148184
quote! {
149185
#[automatically_derived]
150186
impl #impl_generics derive_more::core::cmp::PartialEq for #ty #ty_generics
151187
#where_clause
152188
{
153-
#[inline]
154-
fn eq(&self, __other: &Self) -> bool {
155-
#eq_body
156-
}
189+
#eq_method
190+
#ne_method
157191
}
158192
}
159193
.to_tokens(tokens);

tests/partial_eq.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,18 @@
11
#![cfg_attr(not(feature = "std"), no_std)]
22
#![allow(dead_code)] // some code is tested for type checking only
33

4+
/// Since [`assert_ne!()`] macro in [`core`] doesn't use `$left != $right` comparison, but rather
5+
/// checks as `!($left == $right)`, it should be redefined for tests to consider actual
6+
/// [`PartialEq::ne()`] implementations.
7+
///
8+
/// [`assert_ne!()`]: core::assert_ne
9+
#[macro_export]
10+
macro_rules! assert_ne {
11+
($left:expr, $right:expr $(,)?) => {
12+
assert!($left != $right)
13+
};
14+
}
15+
416
mod structs {
517
mod structural {
618
use derive_more::PartialEq;

0 commit comments

Comments
 (0)