Skip to content

Commit b9ea6ba

Browse files
authored
Merge pull request #66 from JuliaArrays/use-size
Use `Size` more consistently
2 parents 1f035b4 + 9aeb0d1 commit b9ea6ba

File tree

12 files changed

+783
-267
lines changed

12 files changed

+783
-267
lines changed

perf/bench9.txt

Lines changed: 342 additions & 0 deletions
Large diffs are not rendered by default.

perf/benchmark3.jl

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
using StaticArrays
2+
using BenchmarkTools
3+
4+
import BenchmarkTools: prettytime, prettymemory
5+
6+
@noinline plus(a,b) = a+b
7+
@noinline plus!(c,a,b) = broadcast!(+, c, a, b)
8+
9+
@noinline mul(a,b) = a*b
10+
@noinline mul!(c,a,b) = A_mul_B!(c, a, b)
11+
12+
13+
for T [Int64, Float64]
14+
for N [1,2,4,8,16,32,64,128,256]
15+
println("=====================================================================")
16+
println(" Vectors of length ", N, " and eltype ", T)
17+
println("=====================================================================")
18+
immutables = [rand(SVector{N,T})]
19+
mutables = [rand(T,N), rand(MVector{N,T}), Size(N)(rand(T,N))]
20+
instances = vcat(immutables, mutables)
21+
22+
namelengths = [length(string(typeof(v).name.name)) for v instances]
23+
maxnamelength = maximum(namelengths)
24+
25+
for v instances
26+
result = mean(@benchmark plus($(copy(v)), $(copy(v))))
27+
padding = maxnamelength - length(string(typeof(v).name.name))
28+
println(typeof(v).name.name, ":", " " ^ padding, " v3 = v1 + v2 takes ", prettytime(time(result)), ", ", prettymemory(memory(result)), " (GC ", prettytime(gctime(result)) , ")")
29+
end
30+
31+
println()
32+
33+
for v mutables
34+
result = mean(@benchmark plus!($(copy(v)), $(copy(v)), $(copy(v))))
35+
padding = maxnamelength - length(string(typeof(v).name.name))
36+
println(typeof(v).name.name, ":", " " ^ padding, " v3 .= +.(v1, v2) takes ", prettytime(time(result)), ", ", prettymemory(memory(result)), " (GC ", prettytime(gctime(result)) , ")")
37+
end
38+
39+
println()
40+
41+
if N > 16
42+
continue
43+
end
44+
println("=====================================================================")
45+
println(" Matrices of size ", N, "×", N, " and eltype ", T)
46+
println("=====================================================================")
47+
immutables = [rand(SMatrix{N,N,T})]
48+
mutables = [rand(T,N,N), rand(MMatrix{N,N,T}), Size(N,N)(rand(T,N,N))]
49+
instances = vcat(immutables, mutables)
50+
51+
namelengths = [length(string(typeof(v).name.name)) for v instances]
52+
maxnamelength = maximum(namelengths)
53+
54+
for m instances
55+
result = mean(@benchmark mul($(copy(m)), $(copy(m))))
56+
padding = maxnamelength - length(string(typeof(m).name.name))
57+
println(typeof(m).name.name, ":", " " ^ padding, " m3 = m1 * m2 takes ", prettytime(time(result)), ", ", prettymemory(memory(result)), " (GC ", prettytime(gctime(result)) , ")")
58+
end
59+
60+
println()
61+
62+
for m mutables
63+
result = mean(@benchmark mul!($(copy(m)), $(copy(m)), $(copy(m))))
64+
padding = maxnamelength - length(string(typeof(m).name.name))
65+
println(typeof(m).name.name, ":", " " ^ padding, " A_mul_B!(m3, m1, m2) takes ", prettytime(time(result)), ", ", prettymemory(memory(result)), " (GC ", prettytime(gctime(result)) , ")")
66+
end
67+
68+
println()
69+
70+
end
71+
end

src/SizedArray.jl

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,21 +12,54 @@ array may be reshaped.
1212
immutable SizedArray{S,T,N,M} <: StaticArray{T,N}
1313
data::Array{T,M}
1414

15-
function SizedArray(a)
15+
function SizedArray(a::Array)
1616
if length(a) != prod(S)
1717
error("Dimensions $(size(a)) don't match static size $S")
1818
end
1919
new(a)
2020
end
21+
22+
function SizedArray()
23+
new(Array{T,M}(S))
24+
end
2125
end
2226

2327
@inline (::Type{SizedArray{S,T,N}}){S,T,N,M}(a::Array{T,M}) = SizedArray{S,T,N,M}(a)
2428
@inline (::Type{SizedArray{S,T}}){S,T,M}(a::Array{T,M}) = SizedArray{S,T,_ndims(S),M}(a)
2529
@inline (::Type{SizedArray{S}}){S,T,M}(a::Array{T,M}) = SizedArray{S,T,_ndims(S),M}(a)
2630

31+
@inline (::Type{SizedArray{S,T,N}}){S,T,N}() = SizedArray{S,T,N,N}()
32+
@inline (::Type{SizedArray{S,T}}){S,T}() = SizedArray{S,T,_ndims(S),_ndims(S)}()
33+
34+
@generated function (::Type{SizedArray{S,T,N,M}}){S,T,N,M,L}(x::NTuple{L})
35+
if L != prod(S)
36+
error("Dimension mismatch")
37+
end
38+
exprs = [:(a[$i] = x[$i]) for i = 1:L]
39+
return quote
40+
$(Expr(:meta, :inline))
41+
a = SizedArray{S,T,N,M}()
42+
@inbounds $(Expr(:block, exprs...))
43+
return a
44+
end
45+
end
46+
47+
@inline (::Type{SizedArray{S,T,N}}){S,T,N}(x::Tuple) = SizedArray{S,T,N,N}(x)
48+
@inline (::Type{SizedArray{S,T}}){S,T}(x::Tuple) = SizedArray{S,T,_dims(S),_dims(S)}(x)
49+
@inline (::Type{SizedArray{S}}){S,T,L}(x::NTuple{L,T}) = SizedArray{S,T,_dims(S),_dims(S)}(x)
50+
2751
# Overide some problematic default behaviour
2852
@inline convert{SA<:SizedArray}(::Type{SA}, sa::SizedArray) = SA(sa.data)
2953

54+
# Back to Array (unfortunately need both convert and construct to overide other methods)
55+
@inline (::Type{Array})(sa::SizedArray) = sa.data
56+
@inline (::Type{Array{T}}){T,S}(sa::SizedArray{S,T}) = sa.data
57+
@inline (::Type{Array{T,N}}){T,S,N}(sa::SizedArray{S,T,N}) = sa.data
58+
59+
@inline convert(::Type{Array}, sa::SizedArray) = sa.data
60+
@inline convert{T,S}(::Type{Array{T}}, sa::SizedArray{S,T}) = sa.data
61+
@inline convert{T,S,N}(::Type{Array{T,N}}, sa::SizedArray{S,T,N}) = sa.data
62+
3063
@pure _ndims{N}(::NTuple{N,Int}) = N
3164

3265
@pure size{S}(::Type{SizedArray{S}}) = S
@@ -38,10 +71,18 @@ end
3871
@propagate_inbounds setindex!(a::SizedArray, v, i::Int) = setindex!(a.data, v, i)
3972

4073
typealias SizedVector{S,T,M} SizedArray{S,T,1,M}
74+
@pure size{S}(::Type{SizedVector{S}}) = S
4175
@inline (::Type{SizedVector{S}}){S,T,M}(a::Array{T,M}) = SizedArray{S,T,1,M}(a)
76+
@inline (::Type{SizedVector{S}}){S,T,L}(x::NTuple{L,T}) = SizedArray{S,T,1,1}(x)
77+
@inline (::Type{Vector})(sa::SizedVector) = sa.data
78+
@inline convert(::Type{Vector}, sa::SizedVector) = sa.data
4279

4380
typealias SizedMatrix{S,T,M} SizedArray{S,T,2,M}
81+
@pure size{S}(::Type{SizedMatrix{S}}) = S
4482
@inline (::Type{SizedMatrix{S}}){S,T,M}(a::Array{T,M}) = SizedArray{S,T,2,M}(a)
83+
@inline (::Type{SizedMatrix{S}}){S,T,L}(x::NTuple{L,T}) = SizedArray{S,T,2,2}(x)
84+
@inline (::Type{Matrix})(sa::SizedMatrix) = sa.data
85+
@inline convert(::Type{Matrix}, sa::SizedMatrix) = sa.data
4586

4687

4788
"""

src/abstractarray.jl

Lines changed: 80 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ end
4141
@pure function similar_type{SA<:StaticArray,T}(::Union{SA,Type{SA}}, ::Type{T}, size::Int)
4242
similar_type(similar_type(SA, T), size)
4343
end
44+
@pure function similar_type{SA<:StaticArray,T,S}(::Union{SA,Type{SA}}, ::Type{T}, size::Size{S})
45+
similar_type(similar_type(SA, T), size)
46+
end
4447
@generated function similar_type{SA<:StaticArray,T}(::Union{SA,Type{SA}}, ::Type{T})
4548
# This function has a strange error (on tests) regarding double-inference, if it is marked @pure
4649
if T == eltype(SA)
@@ -115,6 +118,25 @@ end
115118

116119
@pure similar_type{SA<:StaticArray,N}(::Union{SA,Type{SA}}, sizes::Tuple{Vararg{Int,N}}) = SArray{sizes, eltype(SA), N, prod(sizes)}
117120

121+
@generated function similar_type{SA <: StaticArray,S}(::Union{SA,Type{SA}}, ::Size{S})
122+
if length(S) == 1
123+
return quote
124+
$(Expr(:meta, :inline))
125+
SVector{$(S[1]), $(eltype(SA))}
126+
end
127+
elseif length(S) == 2
128+
return quote
129+
$(Expr(:meta, :inline))
130+
SMatrix{$(S[1]), $(S[2]), $(eltype(SA))}
131+
end
132+
else
133+
return quote
134+
$(Expr(:meta, :inline))
135+
SArray{S, $(eltype(SA)), $(length(S)), $(prod(S))}
136+
end
137+
end
138+
end
139+
118140
# Some specializations for the mutable case
119141
@pure similar_type{MA<:Union{MVector,MMatrix,MArray,SizedArray}}(::Union{MA,Type{MA}}, size::Int) = MVector{size, eltype(MA)}
120142
@pure similar_type{MA<:Union{MVector,MMatrix,MArray,SizedArray}}(::Union{MA,Type{MA}}, sizes::Tuple{Int}) = MVector{sizes[1], eltype(MA)}
@@ -123,23 +145,73 @@ end
123145

124146
@pure similar_type{MA<:Union{MVector,MMatrix,MArray,SizedArray},N}(::Union{MA,Type{MA}}, sizes::Tuple{Vararg{Int,N}}) = MArray{sizes, eltype(MA), N, prod(sizes)}
125147

148+
@generated function similar_type{MA<:Union{MVector,MMatrix,MArray,SizedArray},S}(::Union{MA,Type{MA}}, ::Size{S})
149+
if length(S) == 1
150+
return quote
151+
$(Expr(:meta, :inline))
152+
MVector{$(S[1]), $(eltype(MA))}
153+
end
154+
elseif length(S) == 2
155+
return quote
156+
$(Expr(:meta, :inline))
157+
MMatrix{$(S[1]), $(S[2]), $(eltype(MA))}
158+
end
159+
else
160+
return quote
161+
$(Expr(:meta, :inline))
162+
MArray{S, $(eltype(MA)), $(length(S)), $(prod(S))}
163+
end
164+
end
165+
end
166+
126167
# And also similar() returning mutable StaticArrays
127168
@inline similar{SV <: StaticVector}(::SV) = MVector{length(SV),eltype(SV)}()
128169
@inline similar{SV <: StaticVector, T}(::SV, ::Type{T}) = MVector{length(SV),T}()
129-
@inline similar{SA <: StaticArray}(::SA, sizes::Tuple{Int}) = MVector{sizes[1], eltype(SA)}()
130-
@inline similar{SA <: StaticArray}(::SA, size::Int) = MVector{size, eltype(SA)}()
131-
@inline similar{T}(::StaticArray, ::Type{T}, sizes::Tuple{Int}) = MVector{sizes[1],T}()
132-
@inline similar{T}(::StaticArray, ::Type{T}, size::Int) = MVector{size,T}()
133170

134171
@inline similar{SM <: StaticMatrix}(m::SM) = MMatrix{size(SM,1),size(SM,2),eltype(SM),length(SM)}()
135172
@inline similar{SM <: StaticMatrix, T}(::SM, ::Type{T}) = MMatrix{size(SM,1),size(SM,2),T,length(SM)}()
136-
@inline similar{SA <: StaticArray}(::SA, sizes::Tuple{Int,Int}) = MMatrix{sizes[1], sizes[2], eltype(SA), sizes[1]*sizes[2]}()
137-
@inline similar(a::StaticArray, T::Type, sizes::Tuple{Int,Int}) = MMatrix{sizes[1], sizes[2], T, sizes[1]*sizes[2]}()
138173

139174
@inline similar{SA <: StaticArray}(m::SA) = MArray{size(SA),eltype(SA),ndims(SA),length(SA)}()
140175
@inline similar{SA <: StaticArray,T}(m::SA, ::Type{T}) = MArray{size(SA),T,ndims(SA),length(SA)}()
141-
@inline similar{SA <: StaticArray,N}(m::SA, sizes::NTuple{N, Int}) = MArray{sizes,eltype(SA),N,prod(sizes)}()
142-
@inline similar{SA <: StaticArray,N,T}(m::SA, ::Type{T}, sizes::NTuple{N, Int}) = MArray{sizes,T,N,prod(sizes)}()
176+
177+
@generated function similar{SA <: StaticArray,S}(::SA, ::Size{S})
178+
if length(S) == 1
179+
return quote
180+
$(Expr(:meta, :inline))
181+
MVector{$(S[1]), $(eltype(SA))}()
182+
end
183+
elseif length(S) == 2
184+
return quote
185+
$(Expr(:meta, :inline))
186+
MMatrix{$(S[1]), $(S[2]), $(eltype(SA))}()
187+
end
188+
else
189+
return quote
190+
$(Expr(:meta, :inline))
191+
MArray{S, $(eltype(SA))}()
192+
end
193+
end
194+
end
195+
196+
@generated function similar{SA <: StaticArray, T, S}(::SA, ::Type{T}, ::Size{S})
197+
if length(S) == 1
198+
return quote
199+
$(Expr(:meta, :inline))
200+
MVector{$(S[1]), T}()
201+
end
202+
elseif length(S) == 2
203+
return quote
204+
$(Expr(:meta, :inline))
205+
MMatrix{$(S[1]), $(S[2]), T}()
206+
end
207+
else
208+
return quote
209+
$(Expr(:meta, :inline))
210+
MArray{S, T}()
211+
end
212+
end
213+
end
214+
143215

144216
# This is used in Base.LinAlg quite a lot, and it impacts type stability
145217
# since some functions like expm() branch on a check for Hermitian or Symmetric

src/cholesky.jl

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,20 @@
11
# Generic Cholesky decomposition for fixed-size matrices, mostly unrolled
2+
@inline function Base.chol(A::StaticMatrix)
3+
ishermitian(A) || Base.LinAlg.non_hermitian_error("chol")
4+
_chol(Size(A), A)
5+
end
26

3-
# Currently all sanity checks are disabled!
4-
@generated function Base.chol(A::StaticMatrix)
5-
if size(A) === (1, 1)
6-
return :(_chol1(A))
7-
elseif size(A) === (2, 2)
8-
#ishermitian(A) || Base.LinAlg.non_hermitian_error("chol")
9-
return :(_chol2(A))
10-
elseif size(A) === (3, 3)
11-
#ishermitian(A) || Base.LinAlg.non_hermitian_error("chol")
12-
return :(_chol3(A))
13-
else
14-
return :(chol(Array(A)))
15-
end
7+
@inline function Base.chol{T<:Real, SM <: StaticMatrix}(A::Base.LinAlg.RealHermSymComplexHerm{T,SM})
8+
ishermitian(A) || Base.LinAlg.non_hermitian_error("chol")
9+
_chol(Size(A), A)
10+
end
11+
12+
@inline function Base.chol{SM<:StaticMatrix}(A::Symmetric{SM})
13+
eltype(A) <: Real && (ishermitian(A) || Base.LinAlg.non_hermitian_error("chol"))
14+
_chol(Size(A), A)
1615
end
1716

18-
@generated function _chol1(A::StaticMatrix)
17+
@generated function _chol(::Size{(1,1)}, A::StaticMatrix)
1918
@assert size(A) == (1,1)
2019
T = promote_type(typeof(sqrt(one(eltype(A)))), Float32)
2120
newtype = similar_type(A,T)
@@ -26,8 +25,7 @@ end
2625
end
2726
end
2827

29-
30-
@generated function _chol2(A::StaticMatrix)
28+
@generated function _chol(::Size{(2,2)}, A::StaticMatrix)
3129
@assert size(A) == (2,2)
3230
T = promote_type(typeof(sqrt(one(eltype(A)))), Float32)
3331
newtype = similar_type(A,T)
@@ -41,7 +39,7 @@ end
4139
end
4240
end
4341

44-
@generated function _chol3(A::StaticMatrix)
42+
@generated function _chol(::Size{(3,3)}, A::StaticMatrix)
4543
@assert size(A) == (3,3)
4644
T = promote_type(typeof(sqrt(one(eltype(A)))), Float32)
4745
newtype = similar_type(A,T)
@@ -57,3 +55,6 @@ end
5755
($newtype)((a11, $(zero(T)), $(zero(T)), a12, a22, $(zero(T)), a13, a23, a33))
5856
end
5957
end
58+
59+
# Otherwise default algorithm returning wrapped SizedArray
60+
@inline _chol(s::Size, A::StaticArray) = s(full(chol(Hermitian(Array(A)))))

0 commit comments

Comments
 (0)