Skip to content

Commit 2a36f83

Browse files
authored
inference: remove union-split limit for linear signatures (#37378)
This size limit should be already be imposed elsewhere (tmerge), and should not actually add cost to perform the union/tuple-switching when there is no cartesian product to consider. This permits users to explicitly demand larger union products (for example, with type-asserts or field types) and still expect to get reliable union-splitting at any size in single-dispatch sites.
1 parent b0cfb43 commit 2a36f83

File tree

4 files changed

+31
-15
lines changed

4 files changed

+31
-15
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
2424
end
2525
valid_worlds = WorldRange()
2626
atype_params = unwrap_unionall(atype).parameters
27-
splitunions = 1 < countunionsplit(atype_params) <= InferenceParams(interp).MAX_UNION_SPLITTING
27+
splitunions = 1 < unionsplitcost(atype_params) <= InferenceParams(interp).MAX_UNION_SPLITTING
2828
mts = Core.MethodTable[]
2929
fullmatch = Bool[]
3030
if splitunions
@@ -113,7 +113,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
113113
sigtuple = unwrap_unionall(sig)::DataType
114114
splitunions = false
115115
this_rt = Bottom
116-
# TODO: splitunions = 1 < countunionsplit(sigtuple.parameters) * napplicable <= InferenceParams(interp).MAX_UNION_SPLITTING
116+
# TODO: splitunions = 1 < unionsplitcost(sigtuple.parameters) * napplicable <= InferenceParams(interp).MAX_UNION_SPLITTING
117117
# currently this triggers a bug in inference recursion detection
118118
if splitunions
119119
splitsigs = switchtupleunion(sig)
@@ -654,7 +654,7 @@ function abstract_apply(interp::AbstractInterpreter, @nospecialize(itft), @nospe
654654
end
655655
res = Union{}
656656
nargs = length(aargtypes)
657-
splitunions = 1 < countunionsplit(aargtypes) <= InferenceParams(interp).MAX_APPLY_UNION_ENUM
657+
splitunions = 1 < unionsplitcost(aargtypes) <= InferenceParams(interp).MAX_APPLY_UNION_ENUM
658658
ctypes = Any[Any[aft]]
659659
infos = [Union{Nothing, AbstractIterationInfo}[]]
660660
for i = 1:nargs

base/compiler/ssair/inlining.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1173,7 +1173,7 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState)
11731173
continue
11741174
end
11751175

1176-
nu = countunionsplit(sig.atypes)
1176+
nu = unionsplitcost(sig.atypes)
11771177
if nu == 1 || nu > state.params.MAX_UNION_SPLITTING
11781178
if !isa(info, MethodMatchInfo)
11791179
if state.method_table === nothing

base/compiler/typeutils.jl

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,14 +103,23 @@ function tuple_tail_elem(@nospecialize(init), ct::Vector{Any})
103103
return Vararg{widenconst(t)}
104104
end
105105

106-
function countunionsplit(atypes::Union{SimpleVector,Vector{Any}})
106+
# Gives a cost function over the effort to switch a tuple-union representation
107+
# as a cartesian product, relative to the size of the original representation.
108+
# Thus, we count the longest element as being roughly invariant to being inside
109+
# or outside of the Tuple/Union nesting, though somewhat more expensive to be
110+
# outside than inside because the representation is larger (because and it
111+
# informs the callee whether any splitting is possible).
112+
function unionsplitcost(atypes::Union{SimpleVector,Vector{Any}})
107113
nu = 1
114+
max = 2
108115
for ti in atypes
109116
if isa(ti, Union)
110-
nu, ovf = Core.Intrinsics.checked_smul_int(nu, unionlen(ti::Union))
111-
if ovf
112-
return typemax(Int)
117+
nti = unionlen(ti)
118+
if nti > max
119+
max, nti = nti, max
113120
end
121+
nu, ovf = Core.Intrinsics.checked_smul_int(nu, nti)
122+
ovf && return typemax(Int)
114123
end
115124
end
116125
return nu

test/compiler/inference.jl

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -769,20 +769,23 @@ g11015(::Type{Bool}, ::Bool) = 2.0
769769

770770
# better inference of apply (#20343)
771771
f20343(::String, ::Int) = 1
772-
f20343(::Int, ::String, ::Int, ::Int) = 1
773-
f20343(::Int, ::Int, ::String, ::Int, ::Int, ::Int) = 1
774-
f20343(::Union{Int,String}...) = Int8(1)
772+
f20343(::Int, ::String, ::Int, ::Int) = 2
773+
f20343(::Int, ::Int, ::String, ::Int, ::Int, ::Int) = 3
774+
f20343(::Int, ::Int, ::Int, ::String, ::Int, ::Int, ::Int, ::Int, ::Int, ::Int, ::Int, ::Int) = 4
775+
f20343(::Union{Int,String}...) = Int8(5)
775776
f20343(::Any...) = "no"
776777
function g20343()
777778
n = rand(1:3)
778-
i = ntuple(i->n==i ? "" : 0, 2n)::Union{Tuple{String,Int},Tuple{Int,String,Int,Int},Tuple{Int,Int,String,Int,Int,Int}}
779+
T = Union{Tuple{String, Int}, Tuple{Int, String, Int, Int}, Tuple{Int, Int, String, Int, Int, Int}}
780+
i = ntuple(i -> n == i ? "" : 0, 2n)::T
779781
f20343(i...)
780782
end
781783
@test Base.return_types(g20343, ()) == [Int]
782784
function h20343()
783785
n = rand(1:3)
784-
i = ntuple(i->n==i ? "" : 0, 3)::Union{Tuple{String,Int,Int},Tuple{Int,String,Int},Tuple{Int,Int,String}}
785-
f20343(i..., i...)
786+
T = Union{Tuple{String, Int, Int}, Tuple{Int, String, Int}, Tuple{Int, Int, String}}
787+
i = ntuple(i -> n == i ? "" : 0, 3)::T
788+
f20343(i..., i..., i..., i...)
786789
end
787790
@test Base.return_types(h20343, ()) == [Union{Int8, Int}]
788791
function i20343()
@@ -2099,7 +2102,11 @@ end
20992102
# issue #28356
21002103
# unit test to make sure countunionsplit overflows gracefully
21012104
# we don't care what number is returned as long as it's large
2102-
@test Core.Compiler.countunionsplit(Any[Union{Int32,Int64} for i=1:80]) > 100000
2105+
@test Core.Compiler.unionsplitcost(Any[Union{Int32, Int64} for i=1:80]) > 100000
2106+
@test Core.Compiler.unionsplitcost(Any[Union{Int8, Int16, Int32, Int64}]) == 2
2107+
@test Core.Compiler.unionsplitcost(Any[Union{Int8, Int16, Int32, Int64}, Union{Int8, Int16, Int32, Int64}, Int8]) == 8
2108+
@test Core.Compiler.unionsplitcost(Any[Union{Int8, Int16, Int32, Int64}, Union{Int8, Int16, Int32}, Int8]) == 6
2109+
@test Core.Compiler.unionsplitcost(Any[Union{Int8, Int16, Int32}, Union{Int8, Int16, Int32, Int64}, Int8]) == 6
21032110

21042111
# make sure compiler doesn't hang in union splitting
21052112

0 commit comments

Comments
 (0)