Skip to content

Commit 5a56ecd

Browse files
committed
inference: apply a limit to permitting typesubtract for tuples (from #35600)
1 parent 811b3a3 commit 5a56ecd

File tree

4 files changed

+19
-16
lines changed

4 files changed

+19
-16
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -621,7 +621,7 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n
621621
while valtype !== Any
622622
stateordonet = abstract_call_known(interp, iteratef, nothing, Any[Const(iteratef), itertype, statetype], sv).rt
623623
stateordonet = widenconst(stateordonet)
624-
nounion = typesubtract(stateordonet, Nothing)
624+
nounion = typesubtract(stateordonet, Nothing, 0)
625625
if !isa(nounion, DataType) || !(nounion <: Tuple) || isvatuple(nounion) || length(nounion.parameters) != 2
626626
valtype = Any
627627
break
@@ -814,7 +814,7 @@ function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, fargs::U
814814
tty_lb = tty_ub # TODO: this would be wrong if !isexact_tty, but instanceof_tfunc doesn't preserve this info
815815
if !has_free_typevars(tty_lb) && !has_free_typevars(tty_ub)
816816
ifty = typeintersect(aty, tty_ub)
817-
elty = typesubtract(aty, tty_lb)
817+
elty = typesubtract(aty, tty_lb, InferenceParams(interp).MAX_UNION_SPLITTING)
818818
return Conditional(a, ifty, elty)
819819
end
820820
end
@@ -831,7 +831,7 @@ function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, fargs::U
831831
elseif rt === Const(true)
832832
bty = Union{}
833833
elseif bty isa Type && isdefined(typeof(aty.val), :instance) # can only widen a if it is a singleton
834-
bty = typesubtract(bty, typeof(aty.val))
834+
bty = typesubtract(bty, typeof(aty.val), InferenceParams(interp).MAX_UNION_SPLITTING)
835835
end
836836
return Conditional(b, aty, bty)
837837
end
@@ -841,7 +841,7 @@ function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, fargs::U
841841
elseif rt === Const(true)
842842
aty = Union{}
843843
elseif aty isa Type && isdefined(typeof(bty.val), :instance) # same for b
844-
aty = typesubtract(aty, typeof(bty.val))
844+
aty = typesubtract(aty, typeof(bty.val), InferenceParams(interp).MAX_UNION_SPLITTING)
845845
end
846846
return Conditional(a, bty, aty)
847847
end

base/compiler/typeutils.jl

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -63,20 +63,22 @@ end
6363

6464
# return an upper-bound on type `a` with type `b` removed
6565
# such that `return <: a` && `Union{return, b} == Union{a, b}`
66-
function typesubtract(@nospecialize(a), @nospecialize(b))
66+
function typesubtract(@nospecialize(a), @nospecialize(b), MAX_UNION_SPLITTING::Int)
6767
if a <: b && isnotbrokensubtype(a, b)
6868
return Bottom
6969
end
70-
if isa(a, Union)
71-
return Union{typesubtract(a.a, b),
72-
typesubtract(a.b, b)}
70+
ua = unwrap_unionall(a)
71+
if isa(ua, Union)
72+
return Union{typesubtract(rewrap_unionall(ua.a, a), b, MAX_UNION_SPLITTING),
73+
typesubtract(rewrap_unionall(ua.b, a), b, MAX_UNION_SPLITTING)}
7374
elseif a isa DataType
74-
if b isa DataType
75-
if a.name === b.name === Tuple.name && length(a.types) == length(b.types)
75+
ub = unwrap_unionall(b)
76+
if ub isa DataType
77+
if a.name === ub.name === Tuple.name &&
78+
length(a.parameters) == length(ub.parameters) &&
79+
1 < unionsplitcost(a.parameters) <= MAX_UNION_SPLITTING
7680
ta = switchtupleunion(a)
77-
if length(ta) > 1
78-
return typesubtract(Union{ta...}, b)
79-
end
81+
return typesubtract(Union{ta...}, b, 0)
8082
end
8183
end
8284
end

stdlib/Test/src/Test.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ import Distributed: myid
2828
using Random
2929
using Random: AbstractRNG, default_rng
3030
using InteractiveUtils: gen_call_with_extracted_types
31-
using Core.Compiler: typesubtract
31+
using Base: typesplit
3232

3333
const DISPLAY_FAILED = (
3434
:isequal,
@@ -1393,7 +1393,7 @@ function _inferred(ex, mod, allow = :(Union{}))
13931393
end)
13941394
@assert length(inftypes) == 1
13951395
rettype = result isa Type ? Type{result} : typeof(result)
1396-
rettype <: allow || rettype == typesubtract(inftypes[1], allow) || error("return type $rettype does not match inferred return type $(inftypes[1])")
1396+
rettype <: allow || rettype == typesplit(inftypes[1], allow) || error("return type $rettype does not match inferred return type $(inftypes[1])")
13971397
result
13981398
end
13991399
end)

test/compiler/inference.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2630,7 +2630,8 @@ end
26302630

26312631
f() = _foldl_iter(step, (Missing[],), [0.0], 1)
26322632
end
2633-
@test Core.Compiler.typesubtract(Tuple{Union{Int,Char}}, Tuple{Char}) == Tuple{Int}
2633+
@test Core.Compiler.typesubtract(Tuple{Union{Int,Char}}, Tuple{Char}, 1) == Tuple{Union{Int,Char}}
2634+
@test Core.Compiler.typesubtract(Tuple{Union{Int,Char}}, Tuple{Char}, 2) == Tuple{Int}
26342635
@test Base.return_types(Issue35566.f) == [Val{:expected}]
26352636

26362637
# constant prop through keyword arguments

0 commit comments

Comments
 (0)