Skip to content

Commit 423fd6d

Browse files
committed
improve const eq functionality in patterns
1 parent 9c3a46e commit 423fd6d

File tree

2 files changed

+33
-26
lines changed

2 files changed

+33
-26
lines changed

compiler/rustc_mir_build/src/builder/matches/test.rs

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ use rustc_middle::{bug, span_bug};
1717
use rustc_span::def_id::DefId;
1818
use rustc_span::source_map::Spanned;
1919
use rustc_span::{DUMMY_SP, Span, Symbol, sym};
20+
use rustc_trait_selection::infer::InferCtxtExt;
2021
use tracing::{debug, instrument};
2122

2223
use crate::builder::Builder;
@@ -450,7 +451,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
450451
(_, expect) = coerce(expect_ty, expect, *elem_ty);
451452
}
452453

453-
// Figure out the type on which we are calling `PartialEq`. This involves an extra wrapping
454+
// Figure out the type we are searching for traits against. This involves an extra wrapping
454455
// reference: we can only compare two `&T`, and then compare_ty will be `T`.
455456
// Make sure that we do *not* call any user-defined code here.
456457
// The only types that can end up here are string and byte literals,
@@ -466,10 +467,20 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
466467
_ => span_bug!(source_info.span, "invalid type for non-scalar compare: {}", ty),
467468
};
468469

469-
let eq_def_id =
470+
let mut cmp_trait_def_id =
470471
self.tcx.require_lang_item(LangItem::MatchLoweredCmp, Some(source_info.span));
471-
let method = trait_method(self.tcx, eq_def_id, sym::do_match, [compare_ty, compare_ty]);
472472

473+
let has_pattern_eq = self
474+
.infcx
475+
.type_implements_trait(cmp_trait_def_id, [compare_ty], self.param_env)
476+
.must_apply_modulo_regions();
477+
if !has_pattern_eq {
478+
cmp_trait_def_id =
479+
self.tcx.require_lang_item(LangItem::PartialEq, Some(source_info.span));
480+
}
481+
482+
let method =
483+
trait_method(self.tcx, cmp_trait_def_id, sym::do_match, [compare_ty, compare_ty]);
473484
let bool_ty = self.tcx.types.bool;
474485
let eq_result = self.temp(bool_ty, source_info.span);
475486
let eq_block = self.cfg.start_new_block();

library/core/src/cmp/pattern.rs

Lines changed: 19 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,28 @@ where
88
fn do_match(&self, other: &Rhs) -> bool;
99
}
1010

11-
// other types?
11+
// TODO: tuples, arrays?
12+
13+
macro_rules ! impl_match_lowered_cmp_for_primitive {
14+
($($t:ty),*) => {
15+
$(
16+
impl const MatchLoweredCmp for $t {
17+
fn do_match(&self, other: &Self) -> bool {
18+
*self == *other
19+
}
20+
}
21+
)*
22+
};
23+
}
1224

13-
impl const MatchLoweredCmp for u8 {
14-
fn do_match(&self, other: &Self) -> bool {
15-
*self == *other
16-
}
25+
impl_match_lowered_cmp_for_primitive! {
26+
(), bool, char,
27+
u8, u16, u32, u64, u128, usize,
28+
i8, i16, i32, i64, i128, isize,
29+
f32, f64
1730
}
1831

19-
// shouldn't be const
32+
// note: this wasn't possible before
2033
impl const MatchLoweredCmp for str {
2134
#[rustc_allow_const_fn_unstable(const_trait_impl)]
2235
fn do_match(&self, other: &Self) -> bool {
@@ -47,20 +60,3 @@ where
4760
true
4861
}
4962
}
50-
51-
//impl<const N:usize, T> const MatchLoweredCmp for [T; N] where T: ~const MatchLoweredCmp {
52-
// #[rustc_allow_const_fn_unstable(const_trait_impl)]
53-
// fn do_match(&self, other: &Self) -> bool {
54-
// let mut i = 0;
55-
//
56-
// while i < N {
57-
// if <T as MatchLoweredCmp>::do_match(&self[i], &other[i]) == false {
58-
// return false;
59-
// }
60-
//
61-
// i += 1;
62-
// }
63-
//
64-
// true
65-
// }
66-
//}

0 commit comments

Comments
 (0)