Skip to content

Commit 461435a

Browse files
committed
Enforce builtin binop expectations on single references
Also don't enforce them on non-builtin types
1 parent fa87462 commit 461435a

File tree

3 files changed

+243
-33
lines changed

3 files changed

+243
-33
lines changed

crates/hir-ty/src/chalk_ext.rs

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
//! Various extensions traits for Chalk types.
22
3-
use chalk_ir::{FloatTy, IntTy, Mutability, Scalar, UintTy};
3+
use chalk_ir::{FloatTy, IntTy, Mutability, Scalar, TyVariableKind, UintTy};
44
use hir_def::{
55
builtin_type::{BuiltinFloat, BuiltinInt, BuiltinType, BuiltinUint},
66
generics::TypeOrConstParamData,
@@ -18,6 +18,8 @@ use crate::{
1818

1919
pub trait TyExt {
2020
fn is_unit(&self) -> bool;
21+
fn is_integral(&self) -> bool;
22+
fn is_floating_point(&self) -> bool;
2123
fn is_never(&self) -> bool;
2224
fn is_unknown(&self) -> bool;
2325
fn is_ty_var(&self) -> bool;
@@ -51,6 +53,21 @@ impl TyExt for Ty {
5153
matches!(self.kind(Interner), TyKind::Tuple(0, _))
5254
}
5355

56+
fn is_integral(&self) -> bool {
57+
matches!(
58+
self.kind(Interner),
59+
TyKind::Scalar(Scalar::Int(_) | Scalar::Uint(_))
60+
| TyKind::InferenceVar(_, TyVariableKind::Integer)
61+
)
62+
}
63+
64+
fn is_floating_point(&self) -> bool {
65+
matches!(
66+
self.kind(Interner),
67+
TyKind::Scalar(Scalar::Float(_)) | TyKind::InferenceVar(_, TyVariableKind::Float)
68+
)
69+
}
70+
5471
fn is_never(&self) -> bool {
5572
matches!(self.kind(Interner), TyKind::Never)
5673
}

crates/hir-ty/src/infer/expr.rs

Lines changed: 127 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1071,11 +1071,9 @@ impl<'a> InferenceContext<'a> {
10711071

10721072
let ret_ty = self.normalize_associated_types_in(ret_ty);
10731073

1074-
// use knowledge of built-in binary ops, which can sometimes help inference
1075-
if let Some(builtin_rhs) = self.builtin_binary_op_rhs_expectation(op, lhs_ty.clone()) {
1076-
self.unify(&builtin_rhs, &rhs_ty);
1077-
}
1078-
if let Some(builtin_ret) = self.builtin_binary_op_return_ty(op, lhs_ty, rhs_ty) {
1074+
if self.is_builtin_binop(&lhs_ty, &rhs_ty, op) {
1075+
// use knowledge of built-in binary ops, which can sometimes help inference
1076+
let builtin_ret = self.enforce_builtin_binop_types(&lhs_ty, &rhs_ty, op);
10791077
self.unify(&builtin_ret, &ret_ty);
10801078
}
10811079

@@ -1545,7 +1543,10 @@ impl<'a> InferenceContext<'a> {
15451543
fn builtin_binary_op_rhs_expectation(&mut self, op: BinaryOp, lhs_ty: Ty) -> Option<Ty> {
15461544
Some(match op {
15471545
BinaryOp::LogicOp(..) => TyKind::Scalar(Scalar::Bool).intern(Interner),
1548-
BinaryOp::Assignment { op: None } => lhs_ty,
1546+
BinaryOp::Assignment { op: None } => {
1547+
stdx::never!("Simple assignment operator is not binary op.");
1548+
return None;
1549+
}
15491550
BinaryOp::CmpOp(CmpOp::Eq { .. }) => match self
15501551
.resolve_ty_shallow(&lhs_ty)
15511552
.kind(Interner)
@@ -1565,6 +1566,126 @@ impl<'a> InferenceContext<'a> {
15651566
})
15661567
}
15671568

1569+
/// Dereferences a single level of immutable referencing.
1570+
fn deref_ty_if_possible(&mut self, ty: &Ty) -> Ty {
1571+
let ty = self.resolve_ty_shallow(ty);
1572+
match ty.kind(Interner) {
1573+
TyKind::Ref(Mutability::Not, _, inner) => self.resolve_ty_shallow(inner),
1574+
_ => ty,
1575+
}
1576+
}
1577+
1578+
/// Enforces expectations on lhs type and rhs type depending on the operator and returns the
1579+
/// output type of the binary op.
1580+
fn enforce_builtin_binop_types(&mut self, lhs: &Ty, rhs: &Ty, op: BinaryOp) -> Ty {
1581+
// Special-case a single layer of referencing, so that things like `5.0 + &6.0f32` work (See rust-lang/rust#57447).
1582+
let lhs = self.deref_ty_if_possible(lhs);
1583+
let rhs = self.deref_ty_if_possible(rhs);
1584+
1585+
let (op, is_assign) = match op {
1586+
BinaryOp::Assignment { op: Some(inner) } => (BinaryOp::ArithOp(inner), true),
1587+
_ => (op, false),
1588+
};
1589+
1590+
let output_ty = match op {
1591+
BinaryOp::LogicOp(_) => {
1592+
let bool_ = self.result.standard_types.bool_.clone();
1593+
self.unify(&lhs, &bool_);
1594+
self.unify(&rhs, &bool_);
1595+
bool_
1596+
}
1597+
1598+
BinaryOp::ArithOp(ArithOp::Shl | ArithOp::Shr) => {
1599+
// result type is same as LHS always
1600+
lhs
1601+
}
1602+
1603+
BinaryOp::ArithOp(_) => {
1604+
// LHS, RHS, and result will have the same type
1605+
self.unify(&lhs, &rhs);
1606+
lhs
1607+
}
1608+
1609+
BinaryOp::CmpOp(_) => {
1610+
// LHS and RHS will have the same type
1611+
self.unify(&lhs, &rhs);
1612+
self.result.standard_types.bool_.clone()
1613+
}
1614+
1615+
BinaryOp::Assignment { op: None } => {
1616+
stdx::never!("Simple assignment operator is not binary op.");
1617+
lhs
1618+
}
1619+
1620+
BinaryOp::Assignment { .. } => unreachable!("handled above"),
1621+
};
1622+
1623+
if is_assign {
1624+
self.result.standard_types.unit.clone()
1625+
} else {
1626+
output_ty
1627+
}
1628+
}
1629+
1630+
fn is_builtin_binop(&mut self, lhs: &Ty, rhs: &Ty, op: BinaryOp) -> bool {
1631+
// Special-case a single layer of referencing, so that things like `5.0 + &6.0f32` work (See rust-lang/rust#57447).
1632+
let lhs = self.deref_ty_if_possible(lhs);
1633+
let rhs = self.deref_ty_if_possible(rhs);
1634+
1635+
let op = match op {
1636+
BinaryOp::Assignment { op: Some(inner) } => BinaryOp::ArithOp(inner),
1637+
_ => op,
1638+
};
1639+
1640+
match op {
1641+
BinaryOp::LogicOp(_) => true,
1642+
1643+
BinaryOp::ArithOp(ArithOp::Shl | ArithOp::Shr) => {
1644+
lhs.is_integral() && rhs.is_integral()
1645+
}
1646+
1647+
BinaryOp::ArithOp(
1648+
ArithOp::Add | ArithOp::Sub | ArithOp::Mul | ArithOp::Div | ArithOp::Rem,
1649+
) => {
1650+
lhs.is_integral() && rhs.is_integral()
1651+
|| lhs.is_floating_point() && rhs.is_floating_point()
1652+
}
1653+
1654+
BinaryOp::ArithOp(ArithOp::BitAnd | ArithOp::BitOr | ArithOp::BitXor) => {
1655+
lhs.is_integral() && rhs.is_integral()
1656+
|| lhs.is_floating_point() && rhs.is_floating_point()
1657+
|| matches!(
1658+
(lhs.kind(Interner), rhs.kind(Interner)),
1659+
(TyKind::Scalar(Scalar::Bool), TyKind::Scalar(Scalar::Bool))
1660+
)
1661+
}
1662+
1663+
BinaryOp::CmpOp(_) => {
1664+
let is_scalar = |kind| {
1665+
matches!(
1666+
kind,
1667+
&TyKind::Scalar(_)
1668+
| TyKind::FnDef(..)
1669+
| TyKind::Function(_)
1670+
| TyKind::Raw(..)
1671+
| TyKind::InferenceVar(
1672+
_,
1673+
TyVariableKind::Integer | TyVariableKind::Float
1674+
)
1675+
)
1676+
};
1677+
is_scalar(lhs.kind(Interner)) && is_scalar(rhs.kind(Interner))
1678+
}
1679+
1680+
BinaryOp::Assignment { op: None } => {
1681+
stdx::never!("Simple assignment operator is not binary op.");
1682+
false
1683+
}
1684+
1685+
BinaryOp::Assignment { .. } => unreachable!("handled above"),
1686+
}
1687+
}
1688+
15681689
fn with_breakable_ctx<T>(
15691690
&mut self,
15701691
kind: BreakableKind,

crates/hir-ty/src/tests/traits.rs

Lines changed: 98 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3507,14 +3507,9 @@ trait Request {
35073507
fn bin_op_adt_with_rhs_primitive() {
35083508
check_infer_with_mismatches(
35093509
r#"
3510-
#[lang = "add"]
3511-
pub trait Add<Rhs = Self> {
3512-
type Output;
3513-
fn add(self, rhs: Rhs) -> Self::Output;
3514-
}
3515-
3510+
//- minicore: add
35163511
struct Wrapper(u32);
3517-
impl Add<u32> for Wrapper {
3512+
impl core::ops::Add<u32> for Wrapper {
35183513
type Output = Self;
35193514
fn add(self, rhs: u32) -> Wrapper {
35203515
Wrapper(rhs)
@@ -3527,29 +3522,106 @@ fn main(){
35273522
35283523
}"#,
35293524
expect![[r#"
3530-
72..76 'self': Self
3531-
78..81 'rhs': Rhs
3532-
192..196 'self': Wrapper
3533-
198..201 'rhs': u32
3534-
219..247 '{ ... }': Wrapper
3535-
229..236 'Wrapper': Wrapper(u32) -> Wrapper
3536-
229..241 'Wrapper(rhs)': Wrapper
3537-
237..240 'rhs': u32
3538-
259..345 '{ ...um; }': ()
3539-
269..276 'wrapped': Wrapper
3540-
279..286 'Wrapper': Wrapper(u32) -> Wrapper
3541-
279..290 'Wrapper(10)': Wrapper
3542-
287..289 '10': u32
3543-
300..303 'num': u32
3544-
311..312 '2': u32
3545-
322..325 'res': Wrapper
3546-
328..335 'wrapped': Wrapper
3547-
328..341 'wrapped + num': Wrapper
3548-
338..341 'num': u32
3525+
95..99 'self': Wrapper
3526+
101..104 'rhs': u32
3527+
122..150 '{ ... }': Wrapper
3528+
132..139 'Wrapper': Wrapper(u32) -> Wrapper
3529+
132..144 'Wrapper(rhs)': Wrapper
3530+
140..143 'rhs': u32
3531+
162..248 '{ ...um; }': ()
3532+
172..179 'wrapped': Wrapper
3533+
182..189 'Wrapper': Wrapper(u32) -> Wrapper
3534+
182..193 'Wrapper(10)': Wrapper
3535+
190..192 '10': u32
3536+
203..206 'num': u32
3537+
214..215 '2': u32
3538+
225..228 'res': Wrapper
3539+
231..238 'wrapped': Wrapper
3540+
231..244 'wrapped + num': Wrapper
3541+
241..244 'num': u32
35493542
"#]],
35503543
)
35513544
}
35523545

3546+
#[test]
3547+
fn builtin_binop_expectation_works_on_single_reference() {
3548+
check_types(
3549+
r#"
3550+
//- minicore: add
3551+
use core::ops::Add;
3552+
impl Add<i32> for i32 { type Output = i32 }
3553+
impl Add<&i32> for i32 { type Output = i32 }
3554+
impl Add<u32> for u32 { type Output = u32 }
3555+
impl Add<&u32> for u32 { type Output = u32 }
3556+
3557+
struct V<T>;
3558+
impl<T> V<T> {
3559+
fn default() -> Self { loop {} }
3560+
fn get(&self, _: &T) -> &T { loop {} }
3561+
}
3562+
3563+
fn take_u32(_: u32) {}
3564+
fn minimized() {
3565+
let v = V::default();
3566+
let p = v.get(&0);
3567+
//^ &u32
3568+
take_u32(42 + p);
3569+
}
3570+
"#,
3571+
);
3572+
}
3573+
3574+
#[test]
3575+
fn no_builtin_binop_expectation_for_general_ty_var() {
3576+
// FIXME: Ideally type mismatch should be reported on `take_u32(42 - p)`.
3577+
check_types(
3578+
r#"
3579+
//- minicore: add
3580+
use core::ops::Add;
3581+
impl Add<i32> for i32 { type Output = i32; }
3582+
impl Add<&i32> for i32 { type Output = i32; }
3583+
// This is needed to prevent chalk from giving unique solution to `i32: Add<&?0>` after applying
3584+
// fallback to integer type variable for `42`.
3585+
impl Add<&()> for i32 { type Output = (); }
3586+
3587+
struct V<T>;
3588+
impl<T> V<T> {
3589+
fn default() -> Self { loop {} }
3590+
fn get(&self) -> &T { loop {} }
3591+
}
3592+
3593+
fn take_u32(_: u32) {}
3594+
fn minimized() {
3595+
let v = V::default();
3596+
let p = v.get();
3597+
//^ &{unknown}
3598+
take_u32(42 + p);
3599+
}
3600+
"#,
3601+
);
3602+
}
3603+
3604+
#[test]
3605+
fn no_builtin_binop_expectation_for_non_builtin_types() {
3606+
check_no_mismatches(
3607+
r#"
3608+
//- minicore: default, eq
3609+
struct S;
3610+
impl Default for S { fn default() -> Self { S } }
3611+
impl Default for i32 { fn default() -> Self { 0 } }
3612+
impl PartialEq<S> for i32 { fn eq(&self, _: &S) -> bool { true } }
3613+
impl PartialEq<i32> for i32 { fn eq(&self, _: &S) -> bool { true } }
3614+
3615+
fn take_s(_: S) {}
3616+
fn test() {
3617+
let s = Default::default();
3618+
let _eq = 0 == s;
3619+
take_s(s);
3620+
}
3621+
"#,
3622+
)
3623+
}
3624+
35533625
#[test]
35543626
fn array_length() {
35553627
check_infer(

0 commit comments

Comments
 (0)