Skip to content

Commit 6301885

Browse files
committed
Merge branch 'master' into pull-request/a0d168f3
# Conflicts: # src/det.jl # test/det.jl
2 parents 02b4890 + 79ac5f1 commit 6301885

File tree

12 files changed

+135
-70
lines changed

12 files changed

+135
-70
lines changed

REQUIRE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
julia 0.6.0-pre
1+
julia 0.6.0

src/StaticArrays.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import Base: getindex, setindex!, size, similar, vec, show,
88
length, convert, promote_op, promote_rule, map, map!, reduce, reducedim, mapreducedim,
99
mapreduce, broadcast, broadcast!, conj, transpose, ctranspose,
1010
hcat, vcat, ones, zeros, eye, one, cross, vecdot, reshape, fill,
11-
fill!, det, logdet, inv, eig, eigvals, expm, logm, sqrtm, lyap, trace, diag, vecnorm, norm, dot, diagm, diag,
11+
fill!, det, logdet, inv, eig, eigvals, expm, logm, sqrtm, lyap, trace, kron, diag, vecnorm, norm, dot, diagm, diag,
1212
lu, svd, svdvals, svdfact, factorize, ishermitian, issymmetric, isposdef,
1313
sum, diff, prod, count, any, all, minimum,
1414
maximum, extrema, mean, copy, rand, randn, randexp, rand!, randn!,

src/det.jl

Lines changed: 14 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,64 +1,26 @@
1-
@inline det(A::StaticMatrix) = _det(Size(A), A)
2-
@inline logdet(A::StaticMatrix) = _logdet(Size(A), A)
1+
@inline function det(A::StaticMatrix)
2+
T = eltype(A)
3+
S = typeof((one(T)*zero(T) + zero(T))/one(T))
4+
_det(Size(A),A,S)
5+
end
36

4-
@inline _det(::Size{(1,1)}, A::StaticMatrix) = @inbounds return A[1]
7+
@inline logdet(A::StaticMatrix) = log(det(A))
58

6-
@inline function _det(::Size{(2,2)}, A::StaticMatrix)
7-
@inbounds return A[1]*A[4] - A[3]*A[2]
8-
end
9+
@inline _det(::Size{(1,1)}, A::StaticMatrix,S::Type) = @inbounds return convert(S,A[1])
910

10-
@inline function _det(::Size{(2,2)}, A::StaticMatrix{<:Any, <:Any, <:Unsigned})
11-
@inbounds return Signed(A[1]*A[4]) - Signed(A[3]*A[2])
11+
@inline function _det(::Size{(2,2)}, A::StaticMatrix, S::Type)
12+
A = similar_type(A,S)(A)
13+
@inbounds return A[1]*A[4] - A[3]*A[2]
1214
end
1315

14-
@inline function _det(::Size{(3,3)}, A::StaticMatrix)
16+
@inline function _det(::Size{(3,3)}, A::StaticMatrix, S::Type)
17+
A = similar_type(A,S)(A)
1518
@inbounds x0 = SVector(A[1], A[2], A[3])
1619
@inbounds x1 = SVector(A[4], A[5], A[6])
1720
@inbounds x2 = SVector(A[7], A[8], A[9])
1821
return vecdot(x0, cross(x1, x2))
1922
end
2023

21-
@inline function _det(::Size{(3,3)}, A::StaticMatrix{<:Any, <:Any, <:Unsigned})
22-
@inbounds x0 = SVector(Signed(A[1]), Signed(A[2]), Signed(A[3]))
23-
@inbounds x1 = SVector(Signed(A[4]), Signed(A[5]), Signed(A[6]))
24-
@inbounds x2 = SVector(Signed(A[7]), Signed(A[8]), Signed(A[9]))
25-
return vecdot(x0, cross(x1, x2))
26-
end
27-
28-
@inline function _det(::Size{(4,4)}, A::StaticMatrix)
29-
@inbounds return (
30-
A[13] * A[10] * A[7] * A[4] - A[9] * A[14] * A[7] * A[4] -
31-
A[13] * A[6] * A[11] * A[4] + A[5] * A[14] * A[11] * A[4] +
32-
A[9] * A[6] * A[15] * A[4] - A[5] * A[10] * A[15] * A[4] -
33-
A[13] * A[10] * A[3] * A[8] + A[9] * A[14] * A[3] * A[8] +
34-
A[13] * A[2] * A[11] * A[8] - A[1] * A[14] * A[11] * A[8] -
35-
A[9] * A[2] * A[15] * A[8] + A[1] * A[10] * A[15] * A[8] +
36-
A[13] * A[6] * A[3] * A[12] - A[5] * A[14] * A[3] * A[12] -
37-
A[13] * A[2] * A[7] * A[12] + A[1] * A[14] * A[7] * A[12] +
38-
A[5] * A[2] * A[15] * A[12] - A[1] * A[6] * A[15] * A[12] -
39-
A[9] * A[6] * A[3] * A[16] + A[5] * A[10] * A[3] * A[16] +
40-
A[9] * A[2] * A[7] * A[16] - A[1] * A[10] * A[7] * A[16] -
41-
A[5] * A[2] * A[11] * A[16] + A[1] * A[6] * A[11] * A[16])
42-
end
43-
44-
@inline _logdet(S::Union{Size{(1,1)},Size{(2,2)},Size{(3,3)}}, A::StaticMatrix) = log(_det(S, A))
45-
46-
for (symb, f) in [(:_det, :det), (:_logdet, :logdet)]
47-
eval(quote
48-
@generated function $symb{S}(::Size{S}, A::StaticMatrix)
49-
if S[1] != S[2]
50-
throw(DimensionMismatch("matrix is not square"))
51-
end
52-
return quote # Implementation from Base
53-
@_inline_meta
54-
T = eltype(A)
55-
T2 = typeof((one(T)*zero(T) + zero(T))/one(T))
56-
if istriu(A) || istril(A)
57-
return convert(T2, $($f)(UpperTriangular(A))) # Is this a Julia bug that a convert is not type stable??
58-
end
59-
AA = convert(Array{T2}, A)
60-
return $($f)(lufact(AA))
61-
end
62-
end
63-
end)
24+
@inline function _det(::Size, A::StaticMatrix,::Type)
25+
return det(Matrix(A))
6426
end

src/inv.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ end
1616
S = typeof((one(T)*zero(T) + zero(T))/one(T))
1717
newtype = similar_type(A, S)
1818

19-
@inbounds x0 = SVector(A[1], A[2], A[3])
20-
@inbounds x1 = SVector(A[4], A[5], A[6])
21-
@inbounds x2 = SVector(A[7], A[8], A[9])
19+
@inbounds x0 = SVector{3,S}(A[1], A[2], A[3])
20+
@inbounds x1 = SVector{3,S}(A[4], A[5], A[6])
21+
@inbounds x2 = SVector{3,S}(A[7], A[8], A[9])
2222

2323
y0 = cross(x1,x2)
2424
d = vecdot(x0, y0)

src/linalg.jl

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -172,11 +172,18 @@ end
172172
end
173173
end
174174

175-
@inline diagm(v::StaticVector) = _diagm(Size(v), v)
176-
@generated function _diagm(::Size{S}, v::StaticVector) where {S}
177-
Snew = (S[1], S[1])
175+
@inline diagm(v::StaticVector, k::Type{Val{D}}=Val{0}) where {D} = _diagm(Size(v), v, k)
176+
@generated function _diagm(::Size{S}, v::StaticVector, ::Type{Val{D}}) where {S,D}
177+
S1 = S[1]
178+
Snew1 = S1+abs(D)
179+
Snew = (Snew1, Snew1)
180+
Lnew = Snew1 * Snew1
178181
T = eltype(v)
179-
exprs = [i == j ? :(v[$i]) : zero(T) for i = 1:S[1], j = 1:S[1]]
182+
ind = diagind(Snew1, Snew1, D)
183+
exprs = fill(:(zero($T)), Lnew)
184+
for n = 1:S[1]
185+
exprs[ind[n]] = :(v[$n])
186+
end
180187
return quote
181188
$(Expr(:meta, :inline))
182189
@inbounds return similar_type($v, Size($Snew))(tuple($(exprs...)))
@@ -310,6 +317,21 @@ end
310317
end
311318
end
312319

320+
const _length_limit = Length(200)
321+
322+
@inline kron(a::StaticMatrix, b::StaticMatrix) = _kron(_length_limit, Size(a), Size(b), a, b)
323+
@generated function _kron(::Length{length_limit}, ::Size{SA}, ::Size{SB}, a, b) where {length_limit,SA,SB}
324+
outsize = SA .* SB
325+
if prod(outsize) > length_limit
326+
return :( SizedMatrix{$(outsize[1]),$(outsize[2])}( kron(drop_sdims(a), drop_sdims(b)) ) )
327+
end
328+
rows = [:(hcat($([:(a[$(sub2ind(SA,i,j))]*b) for j=1:SA[2]]...))) for i=1:SA[1]]
329+
return quote
330+
@_inline_meta
331+
@inbounds return vcat($(rows...))
332+
end
333+
end
334+
313335
@inline Size(::Union{RowVector{T, SA}, Type{RowVector{T, SA}}}) where {T, SA <: StaticArray} = Size(1, Size(SA)[1])
314336
@inline Size(::Union{RowVector{T, CA}, Type{RowVector{T, CA}}} where CA <: ConjVector{<:Any, SA}) where {T, SA <: StaticArray} = Size(1, Size(SA)[1])
315337
@inline Size(::Union{Symmetric{T,SA}, Type{Symmetric{T,SA}}}) where {T,SA<:StaticArray} = Size(SA)
@@ -323,4 +345,3 @@ end
323345

324346
@inline Base.LinAlg.Symmetric(A::StaticMatrix, uplo::Char='U') = (Base.LinAlg.checksquare(A);Symmetric{eltype(A),typeof(A)}(A, uplo))
325347
@inline Base.LinAlg.Hermitian(A::StaticMatrix, uplo::Char='U') = (Base.LinAlg.checksquare(A);Hermitian{eltype(A),typeof(A)}(A, uplo))
326-

src/util.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,3 +59,34 @@ end
5959

6060
return nothing
6161
end
62+
63+
64+
# Trivial view used to drop static dimensions to override dispatch
65+
struct TrivialView{A,T,N} <: AbstractArray{T,N}
66+
a::A
67+
end
68+
69+
size(a::TrivialView) = size(a.a)
70+
getindex(a::TrivialView, inds...) = getindex(a.a, inds...)
71+
setindex!(a::TrivialView, inds...) = setindex!(a.a, inds...)
72+
Base.IndexStyle(::Type{<:TrivialView{A}}) where {A} = IndexStyle(A)
73+
74+
TrivialView(a::AbstractArray{T,N}) where {T,N} = TrivialView{typeof(a),T,N}(a)
75+
76+
77+
# Remove the static dimensions from an array
78+
79+
"""
80+
drop_sdims(a)
81+
82+
Return an `AbstractArray` with the same elements as `a`, but with static
83+
dimensions removed (ie, not a `StaticArray`).
84+
85+
This is useful if you want to override dispatch to call the `Base` version of
86+
operations such as `kron` instead of the implementation in `StaticArrays`.
87+
Normally you shouldn't need to do this, but it can be more efficient for
88+
certain algorithms where the number of elements of the output is a lot larger
89+
than the input.
90+
"""
91+
@inline drop_sdims(a::StaticArray) = TrivialView(a)
92+
@inline drop_sdims(a) = a

test/SizedArray.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@
6868
@test convert(Vector, SizedArray{Tuple{4}, Int, 1}(collect(3:6))) == collect(3:6)
6969
@test Matrix(SMatrix{2,2}((1,2,3,4))) == [1 3; 2 4]
7070
@test convert(Matrix, SMatrix{2,2}((1,2,3,4))) == [1 3; 2 4]
71+
# Conversion after reshaping
72+
@test_broken Array(SizedMatrix{2,2}([1,2,3,4])) == [1 3; 2 4]
7173
end
7274

7375
@testset "promotion" begin

test/det.jl

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,24 @@
11
@testset "Determinant" begin
2-
@test det(@SMatrix [1]) === 1
3-
@test logdet(@SMatrix [1]) === 0.0
4-
@test det(@SMatrix [0 1; 1 0]) === -1
2+
@test det(@SMatrix [1]) == 1
3+
@test logdet(@SMatrix [1]) == 0.0
4+
@test det(@SMatrix [0 1; 1 0]) == -1
55
@test logdet(@SMatrix Complex{Float64}[0 1; 1 0]) == log(det(@SMatrix Complex{Float64}[0 1; 1 0]))
66

7-
@test det(@SMatrix [0 1 0; 1 0 0; 0 0 1]) === -1
7+
@test det(@SMatrix [0 1 0; 1 0 0; 0 0 1]) == -1
88
m = randn(Float64, 4,4)
99
@test det(SMatrix{4,4}(m)) det(m)
1010
#triu/tril
11-
@test det(@SMatrix [1 2; 0 3]) === 3
11+
@test det(@SMatrix [1 2; 0 3]) == 3
1212
@test det(@SMatrix [1 2 3 4; 0 5 6 7; 0 0 8 9; 0 0 0 10]) == 400.0
1313
@test logdet(@SMatrix [1 2 3 4; 0 5 6 7; 0 0 8 9; 0 0 0 10]) log(400.0)
1414
@test @inferred(det(ones(SMatrix{10,10,Complex{Float64}}))) == 0
1515

16-
# Unsigned specializations
17-
@test det(@SMatrix [0x00 0x01; 0x01 0x00])::Int8 == -1
18-
@test det(@SMatrix [0x00 0x01 0x00; 0x01 0x00 0x00; 0x00 0x00 0x01])::Int8 == -1
16+
# Unsigned specializations , compare to Base
17+
M = @SMatrix [1 2 3 4; 200 5 6 7; 0 0 8 9; 0 0 0 10]
18+
for sz in (2,3,4), typ in (UInt8,UInt16,UInt32,UInt64)
19+
Mtag = SMatrix{sz,sz,typ}(M[1:sz,1:sz])
20+
@test det(Mtag) == det(Array(Mtag))
21+
end
1922

2023
@test_throws DimensionMismatch det(@SMatrix [0; 1])
2124
end

test/inv.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,14 @@ end
7070
@test norm(Matrix(sm*inv(sm) - eye(4))) < 12*norm(m*inv(m) - eye(4))
7171
end
7272

73+
@testset "Matrix inverse 5x5" begin
74+
m = randn(Float64, 5,5) + eye(5)
75+
@test inv(SMatrix{5,5}(m))::StaticMatrix inv(m)
76+
m = triu(randn(Float64, 5,5) + eye(5))
77+
@test inv(SMatrix{5,5}(m))::StaticMatrix inv(m)
78+
m = tril(randn(Float64, 5,5) + eye(5))
79+
@test inv(SMatrix{5,5}(m))::StaticMatrix inv(m)
80+
end
7381

7482
#-------------------------------------------------------------------------------
7583
# More comprehensive but qualitiative testing for inv() accuracy

test/linalg.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ using StaticArrays, Base.Test
4040

4141
@testset "diagm()" begin
4242
@test @inferred(diagm(SVector(1,2))) === @SMatrix [1 0; 0 2]
43+
@test @inferred(diagm(SVector(1,2,3), Val{2}))::SMatrix == diagm([1,2,3], 2)
44+
@test @inferred(diagm(SVector(1,2,3), Val{-2}))::SMatrix == diagm([1,2,3], -2)
4345
end
4446

4547
@testset "diag()" begin
@@ -147,4 +149,19 @@ using StaticArrays, Base.Test
147149
@test vecnorm(SVector{0, Float64}(()), 1) === 0.
148150
@test trace(SMatrix{0,0,Float64}(())) === 0.
149151
end
152+
153+
@testset "kron" begin
154+
@test @inferred(kron(@SMatrix([1 2; 3 4]), @SMatrix([0 1 0; 1 0 1]))) ==
155+
SMatrix{4,6,Int}([0 1 0 0 2 0;
156+
1 0 1 2 0 2;
157+
0 3 0 0 4 0;
158+
3 0 3 4 0 4])
159+
@test @inferred(kron(@SMatrix([1 2; 3 4]), @SMatrix([2.0]))) === @SMatrix [2.0 4.0; 6.0 8.0]
160+
161+
# Output should be heap allocated into a SizedArray when it gets large
162+
# enough.
163+
M1 = collect(1:20)
164+
M2 = collect(20:-1:1).'
165+
@test @inferred(kron(SMatrix{20,1}(M1),SMatrix{1,20}(M2)))::SizedMatrix{20,20} == kron(M1,M2)
166+
end
150167
end

0 commit comments

Comments
 (0)