Skip to content

Commit 2744c81

Browse files
Fix zero(SVector{3,Any}) etc. with multiple dispatch (#1129)
* fix zero(SVector{3,Any}) etc. with multiple dispatch * replace `SA{Float64}` with `Base.typeintersect(SA, AbstractArray{Float64})` * update tests for `ones` and `zeros` * update tests for `fill` * replace `T` with `SA` * replace `SA{U}` with `Base.typeintersect(SA, AbstractArray{U})` * remove unnecessary methods for `zeros` and `ones`
1 parent d1e595a commit 2744c81

File tree

4 files changed

+142
-114
lines changed

4 files changed

+142
-114
lines changed

src/MVector.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,3 @@
1-
# Some more advanced constructor-like functions
2-
@inline zeros(::Type{MVector{N}}) where {N} = zeros(MVector{N,Float64})
3-
@inline ones(::Type{MVector{N}}) where {N} = ones(MVector{N,Float64})
4-
51
#####################
62
## MVector methods ##
73
#####################

src/SVector.jl

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,3 @@
1-
2-
# Some more advanced constructor-like functions
3-
@inline zeros(::Type{SVector{N}}) where {N} = zeros(SVector{N,Float64})
4-
@inline ones(::Type{SVector{N}}) where {N} = ones(SVector{N,Float64})
5-
61
#####################
72
## SVector methods ##
83
#####################

src/arraymath.jl

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
1-
@inline zeros(::Type{SA}) where {SA <: StaticArray} = _zeros(Size(SA), SA)
1+
@inline zeros(::Type{SA}) where {SA <: StaticArray{<:Tuple}} = zeros(Base.typeintersect(SA, AbstractArray{Float64}))
2+
@inline zeros(::Type{SA}) where {SA <: StaticArray{<:Tuple, T}} where T = _zeros(Size(SA), SA)
23
@generated function _zeros(::Size{s}, ::Type{SA}) where {s, SA <: StaticArray}
34
T = eltype(SA)
4-
if T == Any
5-
T = Float64
6-
end
75
v = [:(zero($T)) for i = 1:prod(s)]
86
if SA <: SArray
97
SA = SArray{Tuple{s...}, T, length(s), prod(s)}
@@ -18,12 +16,10 @@
1816
end
1917
end
2018

21-
@inline ones(::Type{SA}) where {SA <: StaticArray} = _ones(Size(SA), SA)
19+
@inline ones(::Type{SA}) where {SA <: StaticArray{<:Tuple}} = ones(Base.typeintersect(SA, AbstractArray{Float64}))
20+
@inline ones(::Type{SA}) where {SA <: StaticArray{<:Tuple, T}} where T = _ones(Size(SA), SA)
2221
@generated function _ones(::Size{s}, ::Type{SA}) where {s, SA <: StaticArray}
2322
T = eltype(SA)
24-
if T == Any
25-
T = Float64
26-
end
2723
v = [:(one($T)) for i = 1:prod(s)]
2824
if SA <: SArray
2925
SA = SArray{Tuple{s...}, T, length(s), prod(s)}
@@ -38,13 +34,11 @@ end
3834
end
3935
end
4036

41-
@inline fill(val, ::SA) where {SA <: StaticArray} = _fill(val, Size(SA), SA)
42-
@inline fill(val, ::Type{SA}) where {SA <: StaticArray} = _fill(val, Size(SA), SA)
43-
@generated function _fill(val::U, ::Size{s}, ::Type{SA}) where {U, s, SA <: StaticArray}
37+
@inline fill(val, ::SA) where {SA <: StaticArray{<:Tuple}} = _fill(val, Size(SA), SA)
38+
@inline fill(val::U, ::Type{SA}) where {SA <: StaticArray} where U = fill(val, Base.typeintersect(SA, AbstractArray{U}))
39+
@inline fill(val, ::Type{SA}) where {SA <: StaticArray{<:Tuple, T}} where T = _fill(val, Size(SA), SA)
40+
@generated function _fill(val, ::Size{s}, ::Type{SA}) where {s, SA <: StaticArray}
4441
T = eltype(SA)
45-
if T == Any
46-
T = U
47-
end
4842
v = [:val for i = 1:prod(s)]
4943
if SA <: SArray
5044
SA = SArray{Tuple{s...}, T, length(s), prod(s)}

test/arraymath.jl

Lines changed: 134 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -37,66 +37,109 @@ import StaticArrays.arithmetic_closure
3737
end
3838

3939
@testset "ones()" begin
40-
for T in (SVector, MVector, SizedVector)
41-
m = @inferred ones(T{3, Float64})
40+
for SA in (SVector, MVector, SizedVector)
41+
# Float64
42+
m = @inferred ones(SA{3, Float64})
4243
@test m == [1.0, 1.0, 1.0]
43-
@test m isa T{3, Float64}
44-
m = @inferred ones(T{3, Int})
44+
@test m isa SA{3, Float64}
45+
# Int
46+
m = @inferred ones(SA{3, Int})
4547
@test m == [1, 1, 1]
46-
@test m isa T{3, Int}
47-
m = @inferred ones(T{3})
48+
@test m isa SA{3, Int}
49+
# Unspecified
50+
m = @inferred ones(SA{3})
4851
@test m == [1.0, 1.0, 1.0]
49-
@test m isa T{3}
50-
m = @inferred ones(T{0, Float64})
52+
@test m isa SA{3}
53+
# Float64
54+
m = @inferred ones(SA{0, Float64})
5155
@test m == Float64[]
52-
@test m isa T{0, Float64}
53-
m = @inferred ones(T{0, Int})
56+
@test m isa SA{0, Float64}
57+
# Int
58+
m = @inferred ones(SA{0, Int})
5459
@test m == Int[]
55-
@test m isa T{0, Int}
56-
m = @inferred ones(T{0})
60+
@test m isa SA{0, Int}
61+
# Unspecified
62+
m = @inferred ones(SA{0})
5763
@test m == Float64[]
58-
@test m isa T{0}
64+
@test m isa SA{0}
65+
# Any
66+
@test_throws MethodError ones(SA{3, Any})
67+
@test ones(SA{0, Any}) isa SA{0, Any}
5968
end
6069
end
6170

6271
@testset "zero()" begin
63-
for T in (SVector, MVector, SizedVector)
64-
m = @inferred zero(T{3, Float64})
72+
for SA in (SVector, MVector, SizedVector)
73+
# Float64
74+
m = @inferred zero(SA{3, Float64})
6575
@test m == [0.0, 0.0, 0.0]
66-
@test m isa T{3, Float64}
67-
m = @inferred zero(T{3, Int})
76+
@test m isa SA{3, Float64}
77+
# Int
78+
m = @inferred zero(SA{3, Int})
6879
@test m == [0, 0, 0]
69-
@test m isa T{3, Int}
70-
m = @inferred zero(T{3})
80+
@test m isa SA{3, Int}
81+
# Unspecified
82+
m = @inferred zero(SA{3})
7183
@test m == [0.0, 0.0, 0.0]
72-
@test m isa T{3}
73-
m = @inferred zero(T{0, Float64})
84+
@test m isa SA{3}
85+
# Float64 (zero-element)
86+
m = @inferred zero(SA{0, Float64})
7487
@test m == Float64[]
75-
@test m isa T{0, Float64}
76-
m = @inferred zero(T{0, Int})
88+
@test m isa SA{0, Float64}
89+
# Int (zero-element)
90+
m = @inferred zero(SA{0, Int})
7791
@test m == Int[]
78-
@test m isa T{0, Int}
79-
m = @inferred zero(T{0})
92+
@test m isa SA{0, Int}
93+
# Unspecified (zero-element)
94+
m = @inferred zero(SA{0})
8095
@test m == Float64[]
81-
@test m isa T{0}
96+
@test m isa SA{0}
97+
# Any
98+
@test_throws MethodError zeros(SA{3, Any})
99+
@test zeros(SA{0, Any}) isa SA{0, Any}
82100
end
83101
end
84102

85103
@testset "fill()" begin
86104
@test @allocated(fill(0., SMatrix{1, 16, Float64})) == 0 # #81
87105
@test @allocated(fill(0., SMatrix{0, 5, Float64})) == 0
88106

89-
for T in (SMatrix, MMatrix, SizedMatrix)
90-
m = @inferred(fill(3., T{4, 16, Float64}))
91-
@test m isa T{4, 16, Float64}
92-
@test all(m .== 3.)
93-
m = @inferred(fill(3., T{0, 5, Float64}))
94-
@test m isa T{0, 5, Float64}
95-
m = @inferred(fill(3, T{4, 16, Float64}))
96-
@test m isa T{4, 16, Float64}
97-
@test all(m .== 3.)
98-
m = @inferred(fill(3, T{0, 5, Float64}))
99-
@test m isa T{0, 5, Float64}
107+
for SA in (SMatrix, MMatrix, SizedMatrix)
108+
for T in (Float64, Int, Any)
109+
# Float64 -> T
110+
m = @inferred(fill(3.0, SA{4, 16, T}))
111+
@test m isa SA{4, 16, T}
112+
@test all(m .== 3)
113+
# Float64 -> T (zero-element)
114+
m = @inferred(fill(3.0, SA{0, 5, T}))
115+
@test m isa SA{0, 5, T}
116+
@test all(m .== 3)
117+
# Int -> T
118+
m = @inferred(fill(3, SA{4, 16, T}))
119+
@test m isa SA{4, 16, T}
120+
@test all(m .== 3)
121+
# Int -> T (zero-element)
122+
m = @inferred(fill(3, SA{0, 5, T}))
123+
@test m isa SA{0, 5, T}
124+
@test all(m .== 3)
125+
end
126+
127+
# Float64 -> Unspecified
128+
m = @inferred(fill(3.0, SA{4, 16}))
129+
@test m isa SA{4, 16, Float64}
130+
@test all(m .== 3)
131+
# Float64 -> Unspecified (zero-element)
132+
m = @inferred(fill(3.0, SA{0, 5}))
133+
@test m isa SA{0, 5, Float64}
134+
@test all(m .== 3)
135+
# Int -> Unspecified
136+
m = @inferred(fill(3, SA{4, 16}))
137+
@test m isa SA{4, 16, Int}
138+
@test all(m .== 3)
139+
# Int -> Unspecified (zero-element)
140+
m = @inferred(fill(3, SA{0, 5}))
141+
@test m isa SA{0, 5, Int}
142+
@test all(m .== 3)
100143
end
101144
end
102145

@@ -119,21 +162,21 @@ import StaticArrays.arithmetic_closure
119162
m = rand(1:1, SVector{3})
120163
@test rand(m) == 1
121164

122-
for T in (SVector, MVector, SizedVector)
123-
v1 = rand(T{3})
124-
@test v1 isa T{3, Float64}
165+
for SA in (SVector, MVector, SizedVector)
166+
v1 = rand(SA{3})
167+
@test v1 isa SA{3, Float64}
125168
@test all(0 .< v1 .< 1)
126169

127-
v2 = rand(T{0})
128-
@test v2 isa T{0, Float64}
170+
v2 = rand(SA{0})
171+
@test v2 isa SA{0, Float64}
129172
@test all(0 .< v2 .< 1)
130173

131-
v3 = rand(T{3, Float32})
132-
@test v3 isa T{3, Float32}
174+
v3 = rand(SA{3, Float32})
175+
@test v3 isa SA{3, Float32}
133176
@test all(0 .< v3 .< 1)
134177

135-
v4 = rand(T{0, Float32})
136-
@test v4 isa T{0, Float32}
178+
v4 = rand(SA{0, Float32})
179+
@test v4 isa SA{0, Float32}
137180
@test all(0 .< v4 .< 1)
138181
end
139182
end
@@ -148,105 +191,105 @@ import StaticArrays.arithmetic_closure
148191
check = ((m .>= 1) .& (m .<= 2))
149192
@test all(check)
150193

151-
for T in (MVector, SizedVector)
152-
v1 = rand(T{3})
194+
for SA in (MVector, SizedVector)
195+
v1 = rand(SA{3})
153196
rand!(v1)
154-
@test v1 isa T{3, Float64}
197+
@test v1 isa SA{3, Float64}
155198
@test all(0 .< v1 .< 1)
156199

157-
v2 = rand(T{0})
200+
v2 = rand(SA{0})
158201
rand!(v2)
159-
@test v2 isa T{0, Float64}
202+
@test v2 isa SA{0, Float64}
160203
@test all(0 .< v2 .< 1)
161204

162-
v3 = rand(T{3, Float32})
205+
v3 = rand(SA{3, Float32})
163206
rand!(v3)
164-
@test v3 isa T{3, Float32}
207+
@test v3 isa SA{3, Float32}
165208
@test all(0 .< v3 .< 1)
166209

167-
v4 = rand(T{0, Float32})
210+
v4 = rand(SA{0, Float32})
168211
rand!(v4)
169-
@test v4 isa T{0, Float32}
212+
@test v4 isa SA{0, Float32}
170213
@test all(0 .< v4 .< 1)
171214
end
172215
end
173216

174217
@testset "randn()" begin
175-
for T in (SVector, MVector, SizedVector)
176-
v1 = randn(T{3})
177-
@test v1 isa T{3, Float64}
218+
for SA in (SVector, MVector, SizedVector)
219+
v1 = randn(SA{3})
220+
@test v1 isa SA{3, Float64}
178221

179-
v2 = randn(T{0})
180-
@test v2 isa T{0, Float64}
222+
v2 = randn(SA{0})
223+
@test v2 isa SA{0, Float64}
181224

182-
v3 = randn(T{3, Float32})
183-
@test v3 isa T{3, Float32}
225+
v3 = randn(SA{3, Float32})
226+
@test v3 isa SA{3, Float32}
184227

185-
v4 = randn(T{0, Float32})
186-
@test v4 isa T{0, Float32}
228+
v4 = randn(SA{0, Float32})
229+
@test v4 isa SA{0, Float32}
187230
end
188231
end
189232

190233
@testset "randn!()" begin
191-
for T in (MVector, SizedVector)
192-
v1 = randn(T{3})
234+
for SA in (MVector, SizedVector)
235+
v1 = randn(SA{3})
193236
randn!(v1)
194-
@test v1 isa T{3, Float64}
237+
@test v1 isa SA{3, Float64}
195238

196-
v2 = randn(T{0})
239+
v2 = randn(SA{0})
197240
randn!(v2)
198-
@test v2 isa T{0, Float64}
241+
@test v2 isa SA{0, Float64}
199242

200-
v3 = randn(T{3, Float32})
243+
v3 = randn(SA{3, Float32})
201244
randn!(v3)
202-
@test v3 isa T{3, Float32}
245+
@test v3 isa SA{3, Float32}
203246

204-
v4 = randn(T{0, Float32})
247+
v4 = randn(SA{0, Float32})
205248
randn!(v4)
206-
@test v4 isa T{0, Float32}
249+
@test v4 isa SA{0, Float32}
207250
end
208251
end
209252

210253
@testset "randexp()" begin
211-
for T in (SVector, MVector, SizedVector)
212-
v1 = randexp(T{3})
213-
@test v1 isa T{3, Float64}
254+
for SA in (SVector, MVector, SizedVector)
255+
v1 = randexp(SA{3})
256+
@test v1 isa SA{3, Float64}
214257
@test all(0 .< v1)
215258

216-
v2 = randexp(T{0})
217-
@test v2 isa T{0, Float64}
259+
v2 = randexp(SA{0})
260+
@test v2 isa SA{0, Float64}
218261
@test all(0 .< v2)
219262

220-
v3 = randexp(T{3, Float32})
221-
@test v3 isa T{3, Float32}
263+
v3 = randexp(SA{3, Float32})
264+
@test v3 isa SA{3, Float32}
222265
@test all(0 .< v3)
223266

224-
v4 = randexp(T{0, Float32})
225-
@test v4 isa T{0, Float32}
267+
v4 = randexp(SA{0, Float32})
268+
@test v4 isa SA{0, Float32}
226269
@test all(0 .< v4)
227270
end
228271
end
229272

230273
@testset "randexp!()" begin
231-
for T in (MVector, SizedVector)
232-
v1 = randexp(T{3})
274+
for SA in (MVector, SizedVector)
275+
v1 = randexp(SA{3})
233276
randexp!(v1)
234-
@test v1 isa T{3, Float64}
277+
@test v1 isa SA{3, Float64}
235278
@test all(0 .< v1)
236279

237-
v2 = randexp(T{0})
280+
v2 = randexp(SA{0})
238281
randexp!(v2)
239-
@test v2 isa T{0, Float64}
282+
@test v2 isa SA{0, Float64}
240283
@test all(0 .< v2)
241284

242-
v3 = randexp(T{3, Float32})
285+
v3 = randexp(SA{3, Float32})
243286
randexp!(v3)
244-
@test v3 isa T{3, Float32}
287+
@test v3 isa SA{3, Float32}
245288
@test all(0 .< v3)
246289

247-
v4 = randexp(T{0, Float32})
290+
v4 = randexp(SA{0, Float32})
248291
randexp!(v4)
249-
@test v4 isa T{0, Float32}
292+
@test v4 isa SA{0, Float32}
250293
@test all(0 .< v4)
251294
end
252295
end

0 commit comments

Comments
 (0)