Skip to content

Commit 5005108

Browse files
authored
Avoid eltype degrading to Union{} for empty map/broadcast (#664)
This reverts to using Core.Compiler.return_type for map/broadcast, but only in the very restricted case that the output container is completely empty. This is consistent with the way that return_type is used in Base for collect and broadcast for empty collections only.
1 parent febdc19 commit 5005108

File tree

6 files changed

+45
-18
lines changed

6 files changed

+45
-18
lines changed

src/broadcast.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -97,11 +97,15 @@ scalar_getindex(x) = x
9797
scalar_getindex(x::Ref) = x[]
9898

9999
@generated function _broadcast(f, ::Size{newsize}, s::Tuple{Vararg{Size}}, a...) where newsize
100-
first_staticarray = 0
101-
for i = 1:length(a)
102-
if a[i] <: StaticArray
103-
first_staticarray = a[i]
104-
break
100+
first_staticarray = a[findfirst(ai -> ai <: StaticArray, a)]
101+
102+
if prod(newsize) == 0
103+
# Use inference to get eltype in empty case (see also comments in _map)
104+
eltys = [:(eltype(a[$i])) for i 1:length(a)]
105+
return quote
106+
@_inline_meta
107+
T = Core.Compiler.return_type(f, Tuple{$(eltys...)})
108+
@inbounds return similar_type($first_staticarray, T, Size(newsize))()
105109
end
106110
end
107111

src/mapreduce.jl

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,31 @@ end
1818
end
1919

2020
@generated function _map(f, a::AbstractArray...)
21-
i = findfirst(ai -> ai <: StaticArray, a)
22-
if i === nothing
21+
first_staticarray = findfirst(ai -> ai <: StaticArray, a)
22+
if first_staticarray === nothing
2323
return :(throw(ArgumentError("No StaticArray found in argument list")))
2424
end
2525
# Passing the Size as an argument to _map leads to inference issues when
2626
# recursively mapping over nested StaticArrays (see issue #593). Calling
27-
# Size in the generator here is valid because a[i] is known to be a
27+
# Size in the generator here is valid because a[first_staticarray] is known to be a
2828
# StaticArray for which the default Size method is correct. If wrapped
2929
# StaticArrays (with a custom Size method) are to be supported, this will
3030
# no longer be valid.
31-
S = Size(a[i])
31+
S = Size(a[first_staticarray])
32+
33+
if prod(S) == 0
34+
# In the empty case only, use inference to try figuring out a sensible
35+
# eltype, as is done in Base.collect and broadcast.
36+
# See https://github.com/JuliaArrays/StaticArrays.jl/issues/528
37+
eltys = [:(eltype(a[$i])) for i 1:length(a)]
38+
return quote
39+
@_inline_meta
40+
S = same_size(a...)
41+
T = Core.Compiler.return_type(f, Tuple{$(eltys...)})
42+
@inbounds return similar_type(a[$first_staticarray], T, S)()
43+
end
44+
end
45+
3246
exprs = Vector{Expr}(undef, prod(S))
3347
for i 1:prod(S)
3448
tmp = [:(a[$j][$i]) for j 1:length(a)]

test/broadcast.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,6 @@ end
109109
@test @inferred(v2 .- v1) === SVector(0, 2)
110110
@test @inferred(v1 .^ v2) === SVector(1, 16)
111111
@test @inferred(v2 .^ v1) === SVector(1, 16)
112-
# Issue #199: broadcast with empty SArray
113-
@test @inferred(SVector(1) .+ SVector{0,Int}()) === SVector{0,Union{}}()
114-
@test @inferred(SVector{0,Int}() .+ SVector(1)) === SVector{0,Union{}}()
115112
# Issue #200: broadcast with Adjoint
116113
@test @inferred(v1 .+ v2') === @SMatrix [2 5; 3 6]
117114
@test @inferred(v1 .+ transpose(v2)) === @SMatrix [2 5; 3 6]
@@ -142,6 +139,13 @@ end
142139
@test @inferred(zeros(SVector{0}) .+ zeros(SMatrix{0,2})) === zeros(SMatrix{0,2})
143140
m = zeros(MMatrix{0,2})
144141
@test @inferred(broadcast!(+, m, m, zeros(SVector{0}))) == zeros(SMatrix{0,2})
142+
# Issue #199: broadcast with empty SArray
143+
@test @inferred(SVector(1) .+ SVector{0,Int}()) === SVector{0,Int}()
144+
@test @inferred(SVector{0,Int}() .+ SVector(1.0)) === SVector{0,Float64}()
145+
# Issue #528
146+
@test @inferred(isapprox(SMatrix{3,0,Float64}(), SMatrix{3,0,Float64}()))
147+
@test @inferred(broadcast(length, SVector{0,String}())) === SVector{0,Int}()
148+
@test @inferred(broadcast(join, SVector{0,String}(), SVector{0,String}(), SVector{0,String}())) === SVector{0,String}()
145149
end
146150

147151
@testset "Mutating broadcast!" begin

test/linalg.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ using StaticArrays, Test, LinearAlgebra
22

33
@testset "Linear algebra" begin
44

5-
@testset "SVector as a (mathematical) vector space" begin
5+
@testset "SArray as a (mathematical) vector space" begin
66
c = 2
77
v1 = @SVector [2,4,6,8]
88
v2 = @SVector [4,3,2,1]
@@ -14,6 +14,10 @@ using StaticArrays, Test, LinearAlgebra
1414
@test @inferred(v1 + v2) === @SVector [6, 7, 8, 9]
1515
@test @inferred(v1 - v2) === @SVector [-2, 1, 4, 7]
1616

17+
# #528 eltype with empty addition
18+
zm = zeros(SMatrix{3, 0, Float64})
19+
@test @inferred(zm + zm) === zm
20+
1721
# TODO Decide what to do about this stuff:
1822
#v3 = [2,4,6,8]
1923
#v4 = [4,3,2,1]

test/lu.jl

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,5 @@ using StaticArrays, Test, LinearAlgebra
3636

3737
# decomposition is correct
3838
l_u = l*u
39-
if length(l_u) > 0 # Union{} element type breaks norm
40-
@test l*u a[p,:]
41-
else
42-
@test_broken l*u a[p,:]
43-
end
39+
@test l*u a[p,:]
4440
end

test/mapreduce.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,11 @@ using Statistics: mean
2424
v3 = @SVector [1, 2, 3, 4]
2525
map!(+, mv3, v1, v2, v3)
2626
@test mv3 == @MVector [7, 9, 11, 13]
27+
28+
# Output eltype for empty cases #528
29+
@test @inferred(map(/, SVector{0,Int}(), SVector{0,Int}())) === SVector{0,Float64}()
30+
@test @inferred(map(+, SVector{0,Int}(), SVector{0,Float32}())) === SVector{0,Float32}()
31+
@test @inferred(map(length, SVector{0,String}())) === SVector{0,Int}()
2732
end
2833

2934
@testset "[map]reduce and [map]reducedim" begin

0 commit comments

Comments
 (0)