Skip to content

Commit cf92372

Browse files
committed
Fix discriminant handling
1 parent 9723c79 commit cf92372

File tree

2 files changed

+153
-45
lines changed

2 files changed

+153
-45
lines changed

src/common.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ pub(crate) fn codegen_icmp_imm(
162162
}
163163
}
164164
} else {
165-
let rhs = i64::try_from(rhs).expect("codegen_icmp_imm rhs out of range for <128bit int");
165+
let rhs = rhs as i64; // Truncates on purpose in case rhs is actually an unsigned value
166166
fx.bcx.ins().icmp_imm(intcc, lhs, rhs)
167167
}
168168
}

src/discriminant.rs

Lines changed: 152 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
//! Handling of enum discriminants
22
//!
3-
//! Adapted from <https://github.com/rust-lang/rust/blob/d760df5aea483aae041c9a241e7acacf48f75035/src/librustc_codegen_ssa/mir/place.rs>
3+
//! Adapted from <https://github.com/rust-lang/rust/blob/31c0645b9d2539f47eecb096142474b29dc542f7/compiler/rustc_codegen_ssa/src/mir/place.rs>
4+
//! (<https://github.com/rust-lang/rust/pull/104535>)
45
56
use rustc_target::abi::{Int, TagEncoding, Variants};
67

@@ -47,13 +48,18 @@ pub(crate) fn codegen_set_discriminant<'tcx>(
4748
} => {
4849
if variant_index != untagged_variant {
4950
let niche = place.place_field(fx, mir::Field::new(tag_field));
51+
let niche_type = fx.clif_type(niche.layout().ty).unwrap();
5052
let niche_value = variant_index.as_u32() - niche_variants.start().as_u32();
51-
let niche_value = ty::ScalarInt::try_from_uint(
52-
u128::from(niche_value).wrapping_add(niche_start),
53-
niche.layout().size,
54-
)
55-
.unwrap();
56-
let niche_llval = CValue::const_val(fx, niche.layout(), niche_value);
53+
let niche_value = (niche_value as u128).wrapping_add(niche_start);
54+
let niche_value = match niche_type {
55+
types::I128 => {
56+
let lsb = fx.bcx.ins().iconst(types::I64, niche_value as u64 as i64);
57+
let msb = fx.bcx.ins().iconst(types::I64, (niche_value >> 64) as u64 as i64);
58+
fx.bcx.ins().iconcat(lsb, msb)
59+
}
60+
ty => fx.bcx.ins().iconst(ty, niche_value as i64),
61+
};
62+
let niche_llval = CValue::by_val(niche_value, niche.layout());
5763
niche.write_cvalue(fx, niche_llval);
5864
}
5965
}
@@ -96,6 +102,7 @@ pub(crate) fn codegen_get_discriminant<'tcx>(
96102
}
97103
};
98104

105+
let cast_to_size = dest_layout.layout.size();
99106
let cast_to = fx.clif_type(dest_layout.ty).unwrap();
100107

101108
// Read the tag/niche-encoded discriminant from memory.
@@ -114,21 +121,128 @@ pub(crate) fn codegen_get_discriminant<'tcx>(
114121
dest.write_cvalue(fx, res);
115122
}
116123
TagEncoding::Niche { untagged_variant, ref niche_variants, niche_start } => {
117-
// Rebase from niche values to discriminants, and check
118-
// whether the result is in range for the niche variants.
119-
120-
// We first compute the "relative discriminant" (wrt `niche_variants`),
121-
// that is, if `n = niche_variants.end() - niche_variants.start()`,
122-
// we remap `niche_start..=niche_start + n` (which may wrap around)
123-
// to (non-wrap-around) `0..=n`, to be able to check whether the
124-
// discriminant corresponds to a niche variant with one comparison.
125-
// We also can't go directly to the (variant index) discriminant
126-
// and check that it is in the range `niche_variants`, because
127-
// that might not fit in the same type, on top of needing an extra
128-
// comparison (see also the comment on `let niche_discr`).
129-
let relative_discr = if niche_start == 0 {
130-
tag
124+
let tag_size = tag_scalar.size(fx);
125+
let max_unsigned = tag_size.unsigned_int_max();
126+
let max_signed = tag_size.signed_int_max() as u128;
127+
let min_signed = max_signed + 1;
128+
let relative_max = niche_variants.end().as_u32() - niche_variants.start().as_u32();
129+
let niche_end = niche_start.wrapping_add(relative_max as u128) & max_unsigned;
130+
let range = tag_scalar.valid_range(fx);
131+
132+
let sle = |lhs: u128, rhs: u128| -> bool {
133+
// Signed and unsigned comparisons give the same results,
134+
// except that in signed comparisons an integer with the
135+
// sign bit set is less than one with the sign bit clear.
136+
// Toggle the sign bit to do a signed comparison.
137+
(lhs ^ min_signed) <= (rhs ^ min_signed)
138+
};
139+
140+
// We have a subrange `niche_start..=niche_end` inside `range`.
141+
// If the value of the tag is inside this subrange, it's a
142+
// "niche value", an increment of the discriminant. Otherwise it
143+
// indicates the untagged variant.
144+
// A general algorithm to extract the discriminant from the tag
145+
// is:
146+
// relative_tag = tag - niche_start
147+
// is_niche = relative_tag <= (ule) relative_max
148+
// discr = if is_niche {
149+
// cast(relative_tag) + niche_variants.start()
150+
// } else {
151+
// untagged_variant
152+
// }
153+
// However, we will likely be able to emit simpler code.
154+
155+
// Find the least and greatest values in `range`, considered
156+
// both as signed and unsigned.
157+
let (low_unsigned, high_unsigned) =
158+
if range.start <= range.end { (range.start, range.end) } else { (0, max_unsigned) };
159+
let (low_signed, high_signed) = if sle(range.start, range.end) {
160+
(range.start, range.end)
161+
} else {
162+
(min_signed, max_signed)
163+
};
164+
165+
let niches_ule = niche_start <= niche_end;
166+
let niches_sle = sle(niche_start, niche_end);
167+
let cast_smaller = cast_to_size <= tag_size;
168+
169+
// In the algorithm above, we can change
170+
// cast(relative_tag) + niche_variants.start()
171+
// into
172+
// cast(tag + (niche_variants.start() - niche_start))
173+
// if either the casted type is no larger than the original
174+
// type, or if the niche values are contiguous (in either the
175+
// signed or unsigned sense).
176+
let can_incr = cast_smaller || niches_ule || niches_sle;
177+
178+
let data_for_boundary_niche = || -> Option<(IntCC, u128)> {
179+
if !can_incr {
180+
None
181+
} else if niche_start == low_unsigned {
182+
Some((IntCC::UnsignedLessThanOrEqual, niche_end))
183+
} else if niche_end == high_unsigned {
184+
Some((IntCC::UnsignedGreaterThanOrEqual, niche_start))
185+
} else if niche_start == low_signed {
186+
Some((IntCC::SignedLessThanOrEqual, niche_end))
187+
} else if niche_end == high_signed {
188+
Some((IntCC::SignedGreaterThanOrEqual, niche_start))
189+
} else {
190+
None
191+
}
192+
};
193+
194+
let (is_niche, tagged_discr, delta) = if relative_max == 0 {
195+
// Best case scenario: only one tagged variant. This will
196+
// likely become just a comparison and a jump.
197+
// The algorithm is:
198+
// is_niche = tag == niche_start
199+
// discr = if is_niche {
200+
// niche_variants.start()
201+
// } else {
202+
// untagged_variant
203+
// }
204+
let is_niche = codegen_icmp_imm(fx, IntCC::Equal, tag, niche_start as i128);
205+
let tagged_discr =
206+
fx.bcx.ins().iconst(cast_to, niche_variants.start().as_u32() as i64);
207+
(is_niche, tagged_discr, 0)
208+
} else if let Some((predicate, constant)) = data_for_boundary_niche() {
209+
// The niche values are either the lowest or the highest in
210+
// `range`. We can avoid the first subtraction in the
211+
// algorithm.
212+
// The algorithm is now this:
213+
// is_niche = tag <= niche_end
214+
// discr = if is_niche {
215+
// cast(tag + (niche_variants.start() - niche_start))
216+
// } else {
217+
// untagged_variant
218+
// }
219+
// (the first line may instead be tag >= niche_start,
220+
// and may be a signed or unsigned comparison)
221+
// The arithmetic must be done before the cast, so we can
222+
// have the correct wrapping behavior. See issue #104519 for
223+
// the consequences of getting this wrong.
224+
let is_niche = codegen_icmp_imm(fx, predicate, tag, constant as i128);
225+
let delta = (niche_variants.start().as_u32() as u128).wrapping_sub(niche_start);
226+
let incr_tag = if delta == 0 {
227+
tag
228+
} else {
229+
let delta = match fx.bcx.func.dfg.value_type(tag) {
230+
types::I128 => {
231+
let lsb = fx.bcx.ins().iconst(types::I64, delta as u64 as i64);
232+
let msb = fx.bcx.ins().iconst(types::I64, (delta >> 64) as u64 as i64);
233+
fx.bcx.ins().iconcat(lsb, msb)
234+
}
235+
ty => fx.bcx.ins().iconst(ty, delta as i64),
236+
};
237+
fx.bcx.ins().iadd(tag, delta)
238+
};
239+
240+
let cast_tag = clif_intcast(fx, incr_tag, cast_to, !niches_ule);
241+
242+
(is_niche, cast_tag, 0)
131243
} else {
244+
// The special cases don't apply, so we'll have to go with
245+
// the general algorithm.
132246
let niche_start = match fx.bcx.func.dfg.value_type(tag) {
133247
types::I128 => {
134248
let lsb = fx.bcx.ins().iconst(types::I64, niche_start as u64 as i64);
@@ -138,40 +252,34 @@ pub(crate) fn codegen_get_discriminant<'tcx>(
138252
}
139253
ty => fx.bcx.ins().iconst(ty, niche_start as i64),
140254
};
141-
fx.bcx.ins().isub(tag, niche_start)
142-
};
143-
let relative_max = niche_variants.end().as_u32() - niche_variants.start().as_u32();
144-
let is_niche = {
145-
codegen_icmp_imm(
255+
let relative_discr = fx.bcx.ins().isub(tag, niche_start);
256+
let cast_tag = clif_intcast(fx, relative_discr, cast_to, false);
257+
let is_niche = crate::common::codegen_icmp_imm(
146258
fx,
147259
IntCC::UnsignedLessThanOrEqual,
148260
relative_discr,
149261
i128::from(relative_max),
150-
)
262+
);
263+
(is_niche, cast_tag, niche_variants.start().as_u32() as u128)
151264
};
152265

153-
// NOTE(eddyb) this addition needs to be performed on the final
154-
// type, in case the niche itself can't represent all variant
155-
// indices (e.g. `u8` niche with more than `256` variants,
156-
// but enough uninhabited variants so that the remaining variants
157-
// fit in the niche).
158-
// In other words, `niche_variants.end - niche_variants.start`
159-
// is representable in the niche, but `niche_variants.end`
160-
// might not be, in extreme cases.
161-
let niche_discr = {
162-
let relative_discr = if relative_max == 0 {
163-
// HACK(eddyb) since we have only one niche, we know which
164-
// one it is, and we can avoid having a dynamic value here.
165-
fx.bcx.ins().iconst(cast_to, 0)
166-
} else {
167-
clif_intcast(fx, relative_discr, cast_to, false)
266+
let tagged_discr = if delta == 0 {
267+
tagged_discr
268+
} else {
269+
let delta = match cast_to {
270+
types::I128 => {
271+
let lsb = fx.bcx.ins().iconst(types::I64, delta as u64 as i64);
272+
let msb = fx.bcx.ins().iconst(types::I64, (delta >> 64) as u64 as i64);
273+
fx.bcx.ins().iconcat(lsb, msb)
274+
}
275+
ty => fx.bcx.ins().iconst(ty, delta as i64),
168276
};
169-
fx.bcx.ins().iadd_imm(relative_discr, i64::from(niche_variants.start().as_u32()))
277+
fx.bcx.ins().iadd(tagged_discr, delta)
170278
};
171279

172280
let untagged_variant =
173281
fx.bcx.ins().iconst(cast_to, i64::from(untagged_variant.as_u32()));
174-
let discr = fx.bcx.ins().select(is_niche, niche_discr, untagged_variant);
282+
let discr = fx.bcx.ins().select(is_niche, tagged_discr, untagged_variant);
175283
let res = CValue::by_val(discr, dest_layout);
176284
dest.write_cvalue(fx, res);
177285
}

0 commit comments

Comments
 (0)