Skip to content

Commit 3b55dae

Browse files
authored
Merge pull request #37714 from JuliaLang/jn/35600-again
Improve typesubtract for tuples (repeat #35600)
2 parents b55e250 + 47d1f62 commit 3b55dae

File tree

9 files changed

+117
-16
lines changed

9 files changed

+117
-16
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -621,7 +621,7 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n
621621
while valtype !== Any
622622
stateordonet = abstract_call_known(interp, iteratef, nothing, Any[Const(iteratef), itertype, statetype], sv).rt
623623
stateordonet = widenconst(stateordonet)
624-
nounion = typesubtract(stateordonet, Nothing)
624+
nounion = typesubtract(stateordonet, Nothing, 0)
625625
if !isa(nounion, DataType) || !(nounion <: Tuple) || isvatuple(nounion) || length(nounion.parameters) != 2
626626
valtype = Any
627627
break
@@ -814,7 +814,7 @@ function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, fargs::U
814814
tty_lb = tty_ub # TODO: this would be wrong if !isexact_tty, but instanceof_tfunc doesn't preserve this info
815815
if !has_free_typevars(tty_lb) && !has_free_typevars(tty_ub)
816816
ifty = typeintersect(aty, tty_ub)
817-
elty = typesubtract(aty, tty_lb)
817+
elty = typesubtract(aty, tty_lb, InferenceParams(interp).MAX_UNION_SPLITTING)
818818
return Conditional(a, ifty, elty)
819819
end
820820
end
@@ -831,7 +831,7 @@ function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, fargs::U
831831
elseif rt === Const(true)
832832
bty = Union{}
833833
elseif bty isa Type && isdefined(typeof(aty.val), :instance) # can only widen a if it is a singleton
834-
bty = typesubtract(bty, typeof(aty.val))
834+
bty = typesubtract(bty, typeof(aty.val), InferenceParams(interp).MAX_UNION_SPLITTING)
835835
end
836836
return Conditional(b, aty, bty)
837837
end
@@ -841,7 +841,7 @@ function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, fargs::U
841841
elseif rt === Const(true)
842842
aty = Union{}
843843
elseif aty isa Type && isdefined(typeof(bty.val), :instance) # same for b
844-
aty = typesubtract(aty, typeof(bty.val))
844+
aty = typesubtract(aty, typeof(bty.val), InferenceParams(interp).MAX_UNION_SPLITTING)
845845
end
846846
return Conditional(a, bty, aty)
847847
end

base/compiler/typeutils.jl

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ end
4141
# some of these queries, this check can be used to somewhat protect against making incorrect
4242
# decisions based on incorrect subtyping. Note that this check, itself, is broken for
4343
# certain combinations of `a` and `b` where one/both isa/are `Union`/`UnionAll` type(s)s.
44-
isnotbrokensubtype(@nospecialize(a), @nospecialize(b)) = (!iskindtype(b) || !isType(a) || hasuniquerep(a.parameters[1]))
44+
isnotbrokensubtype(@nospecialize(a), @nospecialize(b)) = (!iskindtype(b) || !isType(a) || hasuniquerep(a.parameters[1]) || b <: a)
4545

4646
argtypes_to_type(argtypes::Array{Any,1}) = Tuple{anymap(widenconst, argtypes)...}
4747

@@ -63,13 +63,43 @@ end
6363

6464
# return an upper-bound on type `a` with type `b` removed
6565
# such that `return <: a` && `Union{return, b} == Union{a, b}`
66-
function typesubtract(@nospecialize(a), @nospecialize(b))
66+
function typesubtract(@nospecialize(a), @nospecialize(b), MAX_UNION_SPLITTING::Int)
6767
if a <: b && isnotbrokensubtype(a, b)
6868
return Bottom
6969
end
70-
if isa(a, Union)
71-
return Union{typesubtract(a.a, b),
72-
typesubtract(a.b, b)}
70+
ua = unwrap_unionall(a)
71+
if isa(ua, Union)
72+
return Union{typesubtract(rewrap_unionall(ua.a, a), b, MAX_UNION_SPLITTING),
73+
typesubtract(rewrap_unionall(ua.b, a), b, MAX_UNION_SPLITTING)}
74+
elseif a isa DataType
75+
ub = unwrap_unionall(b)
76+
if ub isa DataType
77+
if a.name === ub.name === Tuple.name &&
78+
length(a.parameters) == length(ub.parameters)
79+
if 1 < unionsplitcost(a.parameters) <= MAX_UNION_SPLITTING
80+
ta = switchtupleunion(a)
81+
return typesubtract(Union{ta...}, b, 0)
82+
elseif b isa DataType
83+
# if exactly one element is not bottom after calling typesubtract
84+
# then the result is all of the elements as normal except that one
85+
notbottom = fill(false, length(a.parameters))
86+
for i = 1:length(notbottom)
87+
ap = a.parameters[i]
88+
bp = b.parameters[i]
89+
notbottom[i] = !(ap <: bp && isnotbrokensubtype(ap, bp))
90+
end
91+
let i = findfirst(notbottom)
92+
if i !== nothing && findnext(notbottom, i + 1) === nothing
93+
ta = collect(a.parameters)
94+
ap = a.parameters[i]
95+
bp = b.parameters[i]
96+
ta[i] = typesubtract(ap, bp, min(2, MAX_UNION_SPLITTING))
97+
return Tuple{ta...}
98+
end
99+
end
100+
end
101+
end
102+
end
73103
end
74104
return a # TODO: improve this bound?
75105
end

base/iterators.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1266,7 +1266,7 @@ else
12661266
# fixpoint.
12671267
approx_iter_type(itrT::Type) = _approx_iter_type(itrT, Base._return_type(iterate, Tuple{itrT}))
12681268
# Not actually called, just passed to return type to avoid
1269-
# having to typesubtract
1269+
# having to typesplit on Nothing
12701270
function doiterate(itr, valstate::Union{Nothing, Tuple{Any, Any}})
12711271
valstate === nothing && return nothing
12721272
val, st = valstate

base/missing.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ Any
3636
!!! compat "Julia 1.3"
3737
This function is exported as of Julia 1.3.
3838
"""
39-
nonmissingtype(::Type{T}) where {T} = Core.Compiler.typesubtract(T, Missing)
39+
nonmissingtype(::Type{T}) where {T} = typesplit(T, Missing)
4040

4141
function nonmissingtype_checked(T::Type)
4242
R = nonmissingtype(T)

base/promotion.jl

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,23 @@ function typejoin(@nospecialize(a), @nospecialize(b))
121121
return Any
122122
end
123123

124+
# return an upper-bound on type `a` with type `b` removed
125+
# such that `return <: a` && `Union{return, b} == Union{a, b}`
126+
# WARNING: this is wrong for some objects for which subtyping is broken
127+
# (Core.Compiler.isnotbrokensubtype), use only simple types for `b`
128+
function typesplit(@nospecialize(a), @nospecialize(b))
129+
@_pure_meta
130+
if a <: b
131+
return Bottom
132+
end
133+
if isa(a, Union)
134+
return Union{typesplit(a.a, b),
135+
typesplit(a.b, b)}
136+
end
137+
return a
138+
end
139+
140+
124141
"""
125142
promote_typejoin(T, S)
126143
@@ -132,7 +149,7 @@ function promote_typejoin(@nospecialize(a), @nospecialize(b))
132149
c = typejoin(_promote_typesubtract(a), _promote_typesubtract(b))
133150
return Union{a, b, c}::Type
134151
end
135-
_promote_typesubtract(@nospecialize(a)) = Core.Compiler.typesubtract(a, Union{Nothing, Missing})
152+
_promote_typesubtract(@nospecialize(a)) = typesplit(a, Union{Nothing, Missing})
136153

137154

138155
# Returns length, isfixed

base/set.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -570,7 +570,7 @@ promote_valuetype(x::Pair{K, V}, y::Pair...) where {K, V} =
570570
# Subtract singleton types which are going to be replaced
571571
function subtract_singletontype(::Type{T}, x::Pair{K}) where {T, K}
572572
if issingletontype(K)
573-
Core.Compiler.typesubtract(T, K)
573+
typesplit(T, K)
574574
else
575575
T
576576
end

base/some.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ Some(::Type{T}) where {T} = Some{Type{T}}(T)
1616

1717
promote_rule(::Type{Some{T}}, ::Type{Some{S}}) where {T, S<:T} = Some{T}
1818

19-
nonnothingtype(::Type{T}) where {T} = Core.Compiler.typesubtract(T, Nothing)
19+
nonnothingtype(::Type{T}) where {T} = typesplit(T, Nothing)
2020
promote_rule(T::Type{Nothing}, S::Type) = Union{S, Nothing}
2121
function promote_rule(T::Type{>:Nothing}, S::Type)
2222
R = nonnothingtype(T)

stdlib/Test/src/Test.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ import Distributed: myid
2828
using Random
2929
using Random: AbstractRNG, default_rng
3030
using InteractiveUtils: gen_call_with_extracted_types
31-
using Core.Compiler: typesubtract
31+
using Base: typesplit
3232

3333
const DISPLAY_FAILED = (
3434
:isequal,
@@ -1393,7 +1393,7 @@ function _inferred(ex, mod, allow = :(Union{}))
13931393
end)
13941394
@assert length(inftypes) == 1
13951395
rettype = result isa Type ? Type{result} : typeof(result)
1396-
rettype <: allow || rettype == typesubtract(inftypes[1], allow) || error("return type $rettype does not match inferred return type $(inftypes[1])")
1396+
rettype <: allow || rettype == typesplit(inftypes[1], allow) || error("return type $rettype does not match inferred return type $(inftypes[1])")
13971397
result
13981398
end
13991399
end)

test/compiler/inference.jl

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2592,6 +2592,60 @@ end
25922592

25932593
@test map(>:, [Int], [Int]) == [true]
25942594

2595+
# issue 35566
2596+
module Issue35566
2597+
function step(acc, x)
2598+
xs, = acc
2599+
y = x > 0.0 ? x : missing
2600+
if y isa eltype(xs)
2601+
ys = push!(xs, y)
2602+
else
2603+
ys = vcat(xs, [y])
2604+
end
2605+
return (ys,)
2606+
end
2607+
2608+
function probe(y)
2609+
if y isa Tuple{Vector{Missing}}
2610+
return Val(:missing)
2611+
else
2612+
return Val(:expected)
2613+
end
2614+
end
2615+
2616+
function _foldl_iter(rf, val::T, iter, state) where {T}
2617+
while true
2618+
ret = iterate(iter, state)
2619+
ret === nothing && break
2620+
x, state = ret
2621+
y = rf(val, x)
2622+
if y isa T
2623+
val = y
2624+
else
2625+
return probe(y)
2626+
end
2627+
end
2628+
return Val(:expected)
2629+
end
2630+
2631+
f() = _foldl_iter(step, (Missing[],), [0.0], 1)
2632+
end
2633+
@test Core.Compiler.typesubtract(Tuple{Union{Int,Char}}, Tuple{Char}, 0) == Tuple{Int}
2634+
@test Core.Compiler.typesubtract(Tuple{Union{Int,Char}}, Tuple{Char}, 1) == Tuple{Int}
2635+
@test Core.Compiler.typesubtract(Tuple{Union{Int,Char}}, Tuple{Char}, 2) == Tuple{Int}
2636+
@test Core.Compiler.typesubtract(NTuple{3, Union{Int, Char}}, Tuple{Char, Any, Any}, 0) ==
2637+
Tuple{Int, Union{Char, Int}, Union{Char, Int}}
2638+
@test Core.Compiler.typesubtract(NTuple{3, Union{Int, Char}}, Tuple{Char, Any, Any}, 10) ==
2639+
Union{Tuple{Int, Char, Char}, Tuple{Int, Char, Int}, Tuple{Int, Int, Char}, Tuple{Int, Int, Int}}
2640+
@test Core.Compiler.typesubtract(NTuple{3, Union{Int, Char}}, NTuple{3, Char}, 0) ==
2641+
NTuple{3, Union{Int, Char}}
2642+
@test Core.Compiler.typesubtract(NTuple{3, Union{Int, Char}}, NTuple{3, Char}, 10) ==
2643+
Union{Tuple{Char, Char, Int}, Tuple{Char, Int, Char}, Tuple{Char, Int, Int}, Tuple{Int, Char, Char},
2644+
Tuple{Int, Char, Int}, Tuple{Int, Int, Char}, Tuple{Int, Int, Int}}
2645+
2646+
2647+
@test Base.return_types(Issue35566.f) == [Val{:expected}]
2648+
25952649
# constant prop through keyword arguments
25962650
_unstable_kw(;x=1,y=2) = x == 1 ? 0 : ""
25972651
_use_unstable_kw_1() = _unstable_kw(x = 2)

0 commit comments

Comments
 (0)