Skip to content

Commit 47d1f62

Browse files
committed
inference: also handle typesubtract for tuples where only one parameter remains
1 parent ad977cb commit 47d1f62

File tree

2 files changed

+36
-5
lines changed

2 files changed

+36
-5
lines changed

base/compiler/typeutils.jl

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,10 +75,29 @@ function typesubtract(@nospecialize(a), @nospecialize(b), MAX_UNION_SPLITTING::I
7575
ub = unwrap_unionall(b)
7676
if ub isa DataType
7777
if a.name === ub.name === Tuple.name &&
78-
length(a.parameters) == length(ub.parameters) &&
79-
1 < unionsplitcost(a.parameters) <= MAX_UNION_SPLITTING
80-
ta = switchtupleunion(a)
81-
return typesubtract(Union{ta...}, b, 0)
78+
length(a.parameters) == length(ub.parameters)
79+
if 1 < unionsplitcost(a.parameters) <= MAX_UNION_SPLITTING
80+
ta = switchtupleunion(a)
81+
return typesubtract(Union{ta...}, b, 0)
82+
elseif b isa DataType
83+
# if exactly one element is not bottom after calling typesubtract
84+
# then the result is all of the elements as normal except that one
85+
notbottom = fill(false, length(a.parameters))
86+
for i = 1:length(notbottom)
87+
ap = a.parameters[i]
88+
bp = b.parameters[i]
89+
notbottom[i] = !(ap <: bp && isnotbrokensubtype(ap, bp))
90+
end
91+
let i = findfirst(notbottom)
92+
if i !== nothing && findnext(notbottom, i + 1) === nothing
93+
ta = collect(a.parameters)
94+
ap = a.parameters[i]
95+
bp = b.parameters[i]
96+
ta[i] = typesubtract(ap, bp, min(2, MAX_UNION_SPLITTING))
97+
return Tuple{ta...}
98+
end
99+
end
100+
end
82101
end
83102
end
84103
end

test/compiler/inference.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2630,8 +2630,20 @@ end
26302630

26312631
f() = _foldl_iter(step, (Missing[],), [0.0], 1)
26322632
end
2633-
@test Core.Compiler.typesubtract(Tuple{Union{Int,Char}}, Tuple{Char}, 1) == Tuple{Union{Int,Char}}
2633+
@test Core.Compiler.typesubtract(Tuple{Union{Int,Char}}, Tuple{Char}, 0) == Tuple{Int}
2634+
@test Core.Compiler.typesubtract(Tuple{Union{Int,Char}}, Tuple{Char}, 1) == Tuple{Int}
26342635
@test Core.Compiler.typesubtract(Tuple{Union{Int,Char}}, Tuple{Char}, 2) == Tuple{Int}
2636+
@test Core.Compiler.typesubtract(NTuple{3, Union{Int, Char}}, Tuple{Char, Any, Any}, 0) ==
2637+
Tuple{Int, Union{Char, Int}, Union{Char, Int}}
2638+
@test Core.Compiler.typesubtract(NTuple{3, Union{Int, Char}}, Tuple{Char, Any, Any}, 10) ==
2639+
Union{Tuple{Int, Char, Char}, Tuple{Int, Char, Int}, Tuple{Int, Int, Char}, Tuple{Int, Int, Int}}
2640+
@test Core.Compiler.typesubtract(NTuple{3, Union{Int, Char}}, NTuple{3, Char}, 0) ==
2641+
NTuple{3, Union{Int, Char}}
2642+
@test Core.Compiler.typesubtract(NTuple{3, Union{Int, Char}}, NTuple{3, Char}, 10) ==
2643+
Union{Tuple{Char, Char, Int}, Tuple{Char, Int, Char}, Tuple{Char, Int, Int}, Tuple{Int, Char, Char},
2644+
Tuple{Int, Char, Int}, Tuple{Int, Int, Char}, Tuple{Int, Int, Int}}
2645+
2646+
26352647
@test Base.return_types(Issue35566.f) == [Val{:expected}]
26362648

26372649
# constant prop through keyword arguments

0 commit comments

Comments
 (0)