Skip to content

Simplify discriminant codegen for niche-encoded variants which don't wrap across an integer boundary #143784

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 34 additions & 1 deletion compiler/rustc_abi/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ use std::fmt;
#[cfg(feature = "nightly")]
use std::iter::Step;
use std::num::{NonZeroUsize, ParseIntError};
use std::ops::{Add, AddAssign, Deref, Mul, RangeInclusive, Sub};
use std::ops::{Add, AddAssign, Deref, Mul, RangeFull, RangeInclusive, Sub};
use std::str::FromStr;

use bitflags::bitflags;
Expand Down Expand Up @@ -1391,12 +1391,45 @@ impl WrappingRange {
}

/// Returns `true` when this range spans every value representable in `size`.
///
/// This is deliberately *not* equivalent to `self == WrappingRange::full(size)`:
/// niche calculations can yield full ranges other than the canonical one.
/// For example, `Option<NonZero<u16>>` gets `valid_range: (..=0) | (1..)`.
#[inline]
fn is_full_for(&self, size: Size) -> bool {
    let mask = size.unsigned_int_max();
    debug_assert!(self.start <= mask && self.end <= mask);
    // The range is full exactly when stepping one past `end` (wrapping
    // within `size` bits) lands back on `start`.
    let one_past_end = self.end.wrapping_add(1) & mask;
    one_past_end == self.start
}

/// Tells whether this range avoids wrapping around when its bounds are
/// read as *unsigned* integers of width `size`.
///
/// - `Ok(true)`: the range does not wrap (`start <= end`).
/// - `Ok(false)`: the range wraps around.
/// - `Err(..)`: the range is full for `size`, so whether it "wraps" is a
///   matter of interpretation.
#[inline]
pub fn no_unsigned_wraparound(&self, size: Size) -> Result<bool, RangeFull> {
    // A full range has no meaningful answer; report that separately.
    if self.is_full_for(size) {
        return Err(..);
    }
    Ok(self.start <= self.end)
}

/// Tells whether this range avoids wrapping around when its bounds are
/// read as *signed* integers of width `size`.
///
/// The answer depends strongly on `size`: `100..=200` wraps when viewed as
/// `i8`, but does not when viewed as `i16`.
///
/// - `Ok(true)`: the range does not wrap.
/// - `Ok(false)`: the range wraps around.
/// - `Err(..)`: the range is full for `size`, so whether it "wraps" is a
///   matter of interpretation.
#[inline]
pub fn no_signed_wraparound(&self, size: Size) -> Result<bool, RangeFull> {
    // A full range has no meaningful answer; report that separately.
    if self.is_full_for(size) {
        return Err(..);
    }
    // Compare the bounds after sign-extending each from `size` bits.
    let signed_start: i128 = size.sign_extend(self.start);
    let signed_end: i128 = size.sign_extend(self.end);
    Ok(signed_start <= signed_end)
}
}

impl fmt::Debug for WrappingRange {
Expand Down
125 changes: 96 additions & 29 deletions compiler/rustc_codegen_ssa/src/mir/operand.rs
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,7 @@ impl<'a, 'tcx, V: CodegenObject> OperandRef<'tcx, V> {
// value and the variant index match, since that's all `Niche` can encode.

let relative_max = niche_variants.end().as_u32() - niche_variants.start().as_u32();
let niche_start_const = bx.cx().const_uint_big(tag_llty, niche_start);

// We have a subrange `niche_start..=niche_end` inside `range`.
// If the value of the tag is inside this subrange, it's a
Expand All @@ -511,35 +512,88 @@ impl<'a, 'tcx, V: CodegenObject> OperandRef<'tcx, V> {
// } else {
// untagged_variant
// }
let niche_start = bx.cx().const_uint_big(tag_llty, niche_start);
let is_niche = bx.icmp(IntPredicate::IntEQ, tag, niche_start);
let is_niche = bx.icmp(IntPredicate::IntEQ, tag, niche_start_const);
let tagged_discr =
bx.cx().const_uint(cast_to, niche_variants.start().as_u32() as u64);
(is_niche, tagged_discr, 0)
} else {
// The special cases don't apply, so we'll have to go with
// the general algorithm.
let relative_discr = bx.sub(tag, bx.cx().const_uint_big(tag_llty, niche_start));
// With multiple niched variants we'll have to actually compute
// the variant index from the stored tag.
//
// However, there's still one small optimization we can often do for
// determining *whether* a tag value is a natural value or a niched
// variant. The general algorithm involves a subtraction that often
// wraps in practice, making it tricky to analyse. However, in cases
// where there are few enough possible values of the tag that it doesn't
// need to wrap around, we can instead just look for the contiguous
// tag values on the end of the range with a single comparison.
//
// For example, take the type `enum Demo { A, B, Untagged(bool) }`.
// The `bool` is {0, 1}, and the two other variants are given the
// tags {2, 3} respectively. That means the `tag_range` is
// `[0, 3]`, which doesn't wrap as unsigned (nor as signed), so
// we can test for the niched variants with just `>= 2`.
//
// That means we're looking either for the niche values *above*
// the natural values of the untagged variant:
//
// niche_start niche_end
// | |
// v v
// MIN -------------+---------------------------+---------- MAX
// ^ | is niche |
// | +---------------------------+
// | |
// tag_range.start tag_range.end
//
// Or *below* the natural values:
//
// niche_start niche_end
// | |
// v v
// MIN ----+-----------------------+---------------------- MAX
// | is niche | ^
// +-----------------------+ |
// | |
// tag_range.start tag_range.end
//
// With those two options and having the flexibility to choose
// between a signed or unsigned comparison on the tag, that
// covers most realistic scenarios. The tests have a (contrived)
// example of a 1-byte enum with over 128 niched variants which
// wraps both as signed and as unsigned, though, and for something
// like that we're stuck with the general algorithm.

let tag_range = tag_scalar.valid_range(&dl);
let tag_size = tag_scalar.size(&dl);
let niche_end = u128::from(relative_max).wrapping_add(niche_start);
let niche_end = tag_size.truncate(niche_end);

let relative_discr = bx.sub(tag, niche_start_const);
let cast_tag = bx.intcast(relative_discr, cast_to, false);
let is_niche = bx.icmp(
IntPredicate::IntULE,
relative_discr,
bx.cx().const_uint(tag_llty, relative_max as u64),
);

// Thanks to parameter attributes and load metadata, LLVM already knows
// the general valid range of the tag. It's possible, though, for there
// to be an impossible value *in the middle*, which those ranges don't
// communicate, so it's worth an `assume` to let the optimizer know.
if niche_variants.contains(&untagged_variant)
&& bx.cx().sess().opts.optimize != OptLevel::No
{
let impossible =
u64::from(untagged_variant.as_u32() - niche_variants.start().as_u32());
let impossible = bx.cx().const_uint(tag_llty, impossible);
let ne = bx.icmp(IntPredicate::IntNE, relative_discr, impossible);
bx.assume(ne);
}
let is_niche = if tag_range.no_unsigned_wraparound(tag_size) == Ok(true) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add some description of the special case here? Does a diagram like this make sense?

//         niche_start                       niche_end           
//              |                                |               
//              v                                v               
// 0u8----------+--------------------------------+----------255u8
//         ^    |            is niche            |               
//         |    +--------------------------------+               
//         |                                     |               
// tag_range.start                        tag_range.end          

if niche_start == tag_range.start {
let niche_end_const = bx.cx().const_uint_big(tag_llty, niche_end);
bx.icmp(IntPredicate::IntULE, tag, niche_end_const)
} else {
assert_eq!(niche_end, tag_range.end);
bx.icmp(IntPredicate::IntUGE, tag, niche_start_const)
}
} else if tag_range.no_signed_wraparound(tag_size) == Ok(true) {
if niche_start == tag_range.start {
let niche_end_const = bx.cx().const_uint_big(tag_llty, niche_end);
bx.icmp(IntPredicate::IntSLE, tag, niche_end_const)
} else {
assert_eq!(niche_end, tag_range.end);
bx.icmp(IntPredicate::IntSGE, tag, niche_start_const)
}
} else {
bx.icmp(
IntPredicate::IntULE,
relative_discr,
bx.cx().const_uint(tag_llty, relative_max as u64),
)
};

(is_niche, cast_tag, niche_variants.start().as_u32() as u128)
};
Expand All @@ -550,11 +604,24 @@ impl<'a, 'tcx, V: CodegenObject> OperandRef<'tcx, V> {
bx.add(tagged_discr, bx.cx().const_uint_big(cast_to, delta))
};

let discr = bx.select(
is_niche,
tagged_discr,
bx.cx().const_uint(cast_to, untagged_variant.as_u32() as u64),
);
let untagged_variant_const =
bx.cx().const_uint(cast_to, u64::from(untagged_variant.as_u32()));

// Thanks to parameter attributes and load metadata, LLVM already knows
// the general valid range of the tag. It's possible, though, for there
// to be an impossible value *in the middle*, which those ranges don't
// communicate, so it's worth an `assume` to let the optimizer know.
// Most importantly, this means when optimizing a variant test like
// `SELECT(is_niche, complex, CONST) == CONST` it's ok to simplify that
// to `!is_niche` because the `complex` part can't possibly match.
if niche_variants.contains(&untagged_variant)
&& bx.cx().sess().opts.optimize != OptLevel::No
{
let ne = bx.icmp(IntPredicate::IntNE, tagged_discr, untagged_variant_const);
bx.assume(ne);
}

let discr = bx.select(is_niche, tagged_discr, untagged_variant_const);

// In principle we could insert assumes on the possible range of `discr`, but
// currently in LLVM this isn't worth it because the original `tag` will
Expand Down
Loading
Loading