Skip to content

Commit 7b05367

Browse files
authored
[ADT] Fix specialization of ValueIsPresent for PointerUnion (#121847)
Two instances of `PointerUnion` with different active members and null value compare unequal. Currently, this results in counterintuitive behavior when using functions from `Casting.h`, e.g.: ```C++ PointerUnion<int *, float *> U; // U = (int *)nullptr; dyn_cast<int *>(U); // Aborts dyn_cast<float *>(U); // Aborts U = (float *)nullptr; dyn_cast<int *>(U); // OK dyn_cast<float *>(U); // OK ``` `dyn_cast` should abort in all cases because the argument is null. Currently, it aborts only if the first member is active. This happens because the partial template specialization of `ValueIsPresent` for nullable types compares the union with a union constructed from nullptr, and the two unions compare equal only if their active members are the same. This patch changed the specialization of `ValueIsPresent` for nullable types to make `isPresent()` return false for all possible null values of a PointerUnion, and fixes two places where the old behavior was exploited. Pull Request: #121847
1 parent 799e988 commit 7b05367

File tree

4 files changed

+13
-8
lines changed

4 files changed

+13
-8
lines changed

llvm/include/llvm/Support/Casting.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -614,12 +614,12 @@ template <typename T> struct ValueIsPresent<std::optional<T>> {
614614
static inline decltype(auto) unwrapValue(std::optional<T> &t) { return *t; }
615615
};
616616

617-
// If something is "nullable" then we just compare it to nullptr to see if it
618-
// exists.
617+
// If something is "nullable" then we just cast it to bool to see if it exists.
619618
template <typename T>
620-
struct ValueIsPresent<T, std::enable_if_t<IsNullable<T>>> {
619+
struct ValueIsPresent<
620+
T, std::enable_if_t<IsNullable<T> && std::is_constructible_v<bool, T>>> {
621621
using UnwrappedType = T;
622-
static inline bool isPresent(const T &t) { return t != T(nullptr); }
622+
static inline bool isPresent(const T &t) { return static_cast<bool>(t); }
623623
static inline decltype(auto) unwrapValue(T &t) { return t; }
624624
};
625625

llvm/lib/CodeGen/RegisterBankInfo.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,10 +134,10 @@ const TargetRegisterClass *RegisterBankInfo::constrainGenericRegister(
134134

135135
// If the register already has a class, fallback to MRI::constrainRegClass.
136136
auto &RegClassOrBank = MRI.getRegClassOrRegBank(Reg);
137-
if (isa<const TargetRegisterClass *>(RegClassOrBank))
137+
if (isa_and_present<const TargetRegisterClass *>(RegClassOrBank))
138138
return MRI.constrainRegClass(Reg, &RC);
139139

140-
const RegisterBank *RB = cast<const RegisterBank *>(RegClassOrBank);
140+
const auto *RB = dyn_cast_if_present<const RegisterBank *>(RegClassOrBank);
141141
// Otherwise, all we can do is ensure the bank covers the class, and set it.
142142
if (RB && !RB->covers(RC))
143143
return nullptr;

llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3708,10 +3708,10 @@ const TargetRegisterClass *
37083708
SIRegisterInfo::getConstrainedRegClassForOperand(const MachineOperand &MO,
37093709
const MachineRegisterInfo &MRI) const {
37103710
const RegClassOrRegBank &RCOrRB = MRI.getRegClassOrRegBank(MO.getReg());
3711-
if (const RegisterBank *RB = dyn_cast<const RegisterBank *>(RCOrRB))
3711+
if (const auto *RB = dyn_cast_if_present<const RegisterBank *>(RCOrRB))
37123712
return getRegClassForTypeOnBank(MRI.getType(MO.getReg()), *RB);
37133713

3714-
if (const auto *RC = dyn_cast<const TargetRegisterClass *>(RCOrRB))
3714+
if (const auto *RC = dyn_cast_if_present<const TargetRegisterClass *>(RCOrRB))
37153715
return getAllocatableClass(RC);
37163716

37173717
return nullptr;

llvm/unittests/ADT/PointerUnionTest.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,11 @@ TEST_F(PointerUnionTest, NewCastInfra) {
208208
EXPECT_FALSE(isa<float *>(d4null));
209209
EXPECT_FALSE(isa<long long *>(d4null));
210210

211+
EXPECT_FALSE(isa_and_present<int *>(i4null));
212+
EXPECT_FALSE(isa_and_present<float *>(f4null));
213+
EXPECT_FALSE(isa_and_present<long long *>(l4null));
214+
EXPECT_FALSE(isa_and_present<double *>(d4null));
215+
211216
// test cast<>
212217
EXPECT_EQ(cast<float *>(a), &f);
213218
EXPECT_EQ(cast<int *>(b), &i);

0 commit comments

Comments
 (0)