Skip to content

Commit 24b39ec

Browse files
vtjnashnalimilan
andauthored
Preserve non-concrete types in promote_typejoin (#37019)
It is useful to have `promote_typejoin(Union{Missing, Int}, Float64}` return `Union{Missing, Real}` instead of `Any`, in particular because `zero` is defined on the former but not on the latter. This allows `sum(skipmissing(::NamedTuple))` to work even when it contains only missing values. Fixes #35504, Closes #36939 Co-authored-by: Milan Bouchet-Valat <nalimilan@club.fr>
1 parent c3a63fc commit 24b39ec

File tree

5 files changed

+50
-31
lines changed

5 files changed

+50
-31
lines changed

base/promotion.jl

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -128,20 +128,12 @@ Compute a type that contains both `T` and `S`, which could be
128128
either a parent of both types, or a `Union` if appropriate.
129129
Falls back to [`typejoin`](@ref).
130130
"""
131-
promote_typejoin(@nospecialize(a), @nospecialize(b)) = _promote_typejoin(a, b)::Type
132-
_promote_typejoin(@nospecialize(a), @nospecialize(b)) = typejoin(a, b)
133-
_promote_typejoin(::Type{Nothing}, ::Type{T}) where {T} =
134-
isconcretetype(T) || T === Union{} ? Union{T, Nothing} : Any
135-
_promote_typejoin(::Type{T}, ::Type{Nothing}) where {T} =
136-
isconcretetype(T) || T === Union{} ? Union{T, Nothing} : Any
137-
_promote_typejoin(::Type{Missing}, ::Type{T}) where {T} =
138-
isconcretetype(T) || T === Union{} ? Union{T, Missing} : Any
139-
_promote_typejoin(::Type{T}, ::Type{Missing}) where {T} =
140-
isconcretetype(T) || T === Union{} ? Union{T, Missing} : Any
141-
_promote_typejoin(::Type{Nothing}, ::Type{Missing}) = Union{Nothing, Missing}
142-
_promote_typejoin(::Type{Missing}, ::Type{Nothing}) = Union{Nothing, Missing}
143-
_promote_typejoin(::Type{Nothing}, ::Type{Nothing}) = Nothing
144-
_promote_typejoin(::Type{Missing}, ::Type{Missing}) = Missing
131+
function promote_typejoin(@nospecialize(a), @nospecialize(b))
132+
c = typejoin(_promote_typesubtract(a), _promote_typesubtract(b))
133+
return Union{a, b, c}::Type
134+
end
135+
_promote_typesubtract(@nospecialize(a)) = Core.Compiler.typesubtract(a, Union{Nothing, Missing})
136+
145137

146138
# Returns length, isfixed
147139
function full_va_len(p)

base/tuple.jl

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -113,22 +113,32 @@ function eltype(t::Type{<:Tuple{Vararg{E}}}) where {E}
113113
end
114114
end
115115
eltype(t::Type{<:Tuple}) = _compute_eltype(t)
116-
function _compute_eltype(t::Type{<:Tuple})
116+
function _tuple_unique_fieldtypes(@nospecialize t)
117117
@_pure_meta
118-
@nospecialize t
119-
t isa Union && return promote_typejoin(eltype(t.a), eltype(t.b))
118+
types = IdSet()
119+
t´ = unwrap_unionall(t)
120120
# Given t = Tuple{Vararg{S}} where S<:Real, the various
121121
# unwrapping/wrapping/va-handling here will return Real
122-
= unwrap_unionall(t)
123-
# TODO: handle Union/UnionAll correctly here
124-
# For Tuple{T}, short-circuit promote_typejoin
125-
length(t´.parameters) == 1 && return rewrap_unionall(unwrapva(t´.parameters[1]), t)
126-
r = Union{}
127-
for ti in.parameters
128-
r = promote_typejoin(r, rewrap_unionall(unwrapva(ti), t))
129-
r === Any && break # if we've already reached Any, it can't widen any more
122+
if t isa Union
123+
union!(types, _tuple_unique_fieldtypes(rewrap_unionall(t´.a, t)))
124+
union!(types, _tuple_unique_fieldtypes(rewrap_unionall(t´.b, t)))
125+
else
126+
r = Union{}
127+
for ti in (t´::DataType).parameters
128+
r = push!(types, rewrap_unionall(unwrapva(ti), t))
129+
end
130+
end
131+
return Core.svec(types...)
132+
end
133+
function _compute_eltype(@nospecialize t)
134+
@_pure_meta # TODO: the compiler shouldn't need this
135+
types = _tuple_unique_fieldtypes(t)
136+
return afoldl(types...) do a, b
137+
# if we've already reached Any, it can't widen any more
138+
a === Any && return Any
139+
b === Any && return Any
140+
return promote_typejoin(a, b)
130141
end
131-
return r
132142
end
133143

134144
# version of tail that doesn't throw on empty tuples (used in array indexing)

test/core.jl

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -168,11 +168,21 @@ for T in (Nothing, Missing)
168168
@test Base.promote_typejoin(Int, T) === Union{Int, T}
169169
@test Base.promote_typejoin(T, String) === Union{T, String}
170170
@test Base.promote_typejoin(Vector{Int}, T) === Union{Vector{Int}, T}
171-
@test Base.promote_typejoin(Vector, T) === Any
172-
@test Base.promote_typejoin(Real, T) === Any
173-
@test Base.promote_typejoin(Int, String) === Any
174-
@test Base.promote_typejoin(Int, Union{Float64, T}) === Any
175-
@test Base.promote_typejoin(Int, Union{String, T}) === Any
171+
@test Base.promote_typejoin(Vector, T) === Union{Vector, T}
172+
@test Base.promote_typejoin(Real, T) === Union{Real, T}
173+
for U in (String, Float64)
174+
@test Base.promote_typejoin(Int, U) === typejoin(Int, U)
175+
@test Base.promote_typejoin(Int, Union{U, T}) === Union{typejoin(Int, U), T}
176+
@test Base.promote_typejoin(Union{Int, U}, T) === Union{Union{Int, U}, T}
177+
@test Base.promote_typejoin(Union{T, U}, Int) === Union{typejoin(Int, U), T}
178+
@test Base.promote_typejoin(Union{T, U}, Union{T, Int}) === Union{typejoin(Int, U), T}
179+
@test Base.promote_typejoin(Union{T, U}, Union{Missing, Int}) ===
180+
Union{typejoin(Int, U), T, Missing}
181+
@test Base.promote_typejoin(Union{T, U}, Union{Nothing, Int}) ===
182+
Union{typejoin(Int, U), T, Nothing}
183+
@test Base.promote_typejoin(Union{T, Nothing, U}, Union{Nothing, Missing, Int}) ===
184+
Union{typejoin(Int, U), T, Nothing, Missing}
185+
end
176186
@test Base.promote_typejoin(T, Union{}) === T
177187
@test Base.promote_typejoin(Union{}, T) === T
178188
end

test/missing.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,11 @@ end
491491
@test_throws ArgumentError reduce(x -> x/2, itr)
492492
@test_throws ArgumentError mapreduce(x -> x/2, +, itr)
493493
end
494+
495+
# issue #35504
496+
nt = NamedTuple{(:x, :y),Tuple{Union{Missing, Int},Union{Missing, Float64}}}(
497+
(missing, missing))
498+
@test sum(skipmissing(nt)) === 0
494499
end
495500

496501
@testset "filter" begin

test/namedtuple.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ end
6969
NamedTuple{(:a,), Tuple{Union{Int,Nothing}}}((2,))
7070

7171
@test eltype((a=[1,2], b=[3,4])) === Vector{Int}
72+
@test eltype(NamedTuple{(:x, :y),Tuple{Union{Missing, Int},Union{Missing, Float64}}}(
73+
(missing, missing))) === Union{Real, Missing}
7274

7375
@test Tuple((a=[1,2], b=[3,4])) == ([1,2], [3,4])
7476
@test Tuple(NamedTuple()) === ()

0 commit comments

Comments
 (0)