Skip to content

Commit 9aeb0d1

Browse files
author
Andy Ferris
committed
Finish eig()
* We currently support Hermitian eigenvalue decomposition, even for those not wrapped in `Hermition{}` * Non-hermitian static matrices now error on eig() (there is a run-time test) * Some matrix_multiply `@inbounds`
1 parent 4700bbf commit 9aeb0d1

File tree

5 files changed

+93
-114
lines changed

5 files changed

+93
-114
lines changed

src/eigen.jl

Lines changed: 6 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -3,27 +3,26 @@
33
end
44

55
@inline function eig{T, SM <: StaticMatrix}(A::Base.LinAlg.HermOrSym{T,SM}; permute::Bool=true, scale::Bool=true)
6-
_eig(Size(A), A, permute, scale)
6+
_eig(Size(SM), A, permute, scale)
77
end
88

99
@inline function _eig(s::Size, A::StaticMatrix, permute, scale)
10-
# TODO This is not type stable: both this fast branch, and the implementation of `Base.eigfact`.
11-
# Decision needed in how to proceed. See e.g. JuliaLang/julia#12304
10+
# Only cover the hermitian branch, for now ast least
11+
# This also solves some type-stability issues such as arise in Base
1212
if ishermitian(A)
1313
return _eig(s, Hermitian(A), permute, scale)
14+
else
15+
error("Only hermitian matrices are diagonalizable by *StaticArrays*. Non-Hermitian matrices should be converted to `Array` first.")
1416
end
15-
eigen = eigfact(Array(A); permute=permute, scale=scale)
16-
return (Size(Size(typeof(A))[1])(eigen.values), s(eigen.vectors)) # Return a SizedArray
1717
end
1818

19-
2019
@inline function _eig{T<:Real}(s::Size, A::Base.LinAlg.RealHermSymComplexHerm{T}, permute, scale)
2120
eigen = eigfact(Hermitian(Array(parent(A))); permute=permute, scale=scale)
2221
return (s(eigen.values), s(eigen.vectors)) # Return a SizedArray
2322
end
2423

2524

26-
@inline function _eig(::Size{(1,1)}, A, permute, scale)
25+
@inline function _eig{T<:Real}(::Size{(1,1)}, A::Base.LinAlg.RealHermSymComplexHerm{T}, permute, scale)
2726
@inbounds return (SVector{1,T}((A[1],)), eye(SMatrix{1,1,T}))
2827
end
2928

@@ -148,66 +147,3 @@ end
148147
end
149148
end
150149
end
151-
152-
153-
154-
155-
# TODO: the non-symmetric case: type stable version (real -> real) since it is more useful to us!
156-
157-
#=
158-
@generated function eig{T<:Real, SM <: StaticMatrix}(A::Base.LinAlg.RealHermSymComplexHerm{T,SM}; permute::Bool=true, scale::Bool=true)
159-
if size(SM) == (1,1)
160-
return quote
161-
$(Expr(:meta, :inline))
162-
@inbounds return (SVector{1,T}((A[1],)), eye(SMatrix{1,1,T}))
163-
end
164-
elseif size(SM) == (2,2)
165-
return quote
166-
$(Expr(:meta, :inline))
167-
a = A.data
168-
169-
if m2.uplo == 'U'
170-
@inbounds t_half = real(A[1] + A[4])/2
171-
@inbounds d = real(A[1]*A[4] - A[3]'*A[3]) # Should be real
172-
173-
tmp2 = t_half*t_half - d
174-
tmp2 < 0 ? tmp = zero(tmp2) : tmp = sqrt(tmp2) # Numerically stable for identity matrices, etc.
175-
vals = SVector(t_half - tmp, t_half + tmp)
176-
f
177-
@inbounds if A[3] == 0
178-
@inbounds if A[3] == 0
179-
vecs = eye(SMatrix{2,2,T})
180-
else
181-
@inbounds vecs = @SMatrix [ A[3] A[3] ;
182-
vals[1]-A[1] vals[2]-A[1] ]
183-
end
184-
else
185-
@inbounds v11 = vals[1]-A[4]
186-
@inbounds n1 = sqrt(v11'*v11 + A[2]'*A[2])
187-
v11 = v11 / n1
188-
@inbounds v12 = A[2] / n1
189-
190-
@inbounds v21 = vals[2]-A[4]
191-
@inbounds n2 = sqrt(v21'*v21 + A[2]'*A[2])
192-
v21 = v21 / n2
193-
@inbounds v22 = A[2] / n2
194-
195-
vecs = @SMatrix [ v11 v21 ;
196-
v12 v22 ]
197-
end
198-
return (vals,vecs)
199-
else
200-
201-
end
202-
end
203-
elseif size(SM) == (3,3)
204-
error("not implemented")
205-
else
206-
return quote
207-
$(Expr(:meta, :inline))
208-
eigen = eigfact(A)
209-
return (eigen.values, eigen.vectors)
210-
end
211-
end
212-
end
213-
=#

src/matrix_multiply.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -489,7 +489,7 @@ end
489489
@inbounds return $(Expr(:call, newtype, Expr(:tuple, exprs...)))
490490
end
491491
end
492-
7
492+
493493

494494
# TODO aliasing problems if c === b?
495495
@generated function A_mul_B!{T1,T2,T3}(c::StaticVector{T1}, A::StaticMatrix{T2}, b::StaticVector{T3})
@@ -767,11 +767,11 @@ end
767767

768768
exprs = [:(C[$(sub2ind(s, k1, k2))] = $(Symbol("tmp_$k2"))[$k1]) for k1 = 1:sA[1], k2 = 1:sB[2]]
769769

770-
return Expr(:block,
771-
Expr(:meta,:inline),
772-
vect_exprs...,
773-
exprs...
774-
)
770+
return quote
771+
Expr(:meta,:inline)
772+
@inbounds $(Expr(:block, vect_exprs...))
773+
@inbounds $(Expr(:block, exprs...))
774+
end
775775
end
776776

777777
#function A_mul_B_blas(a, b, c, A, B)

test/eigen.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
@testset "Eigenvalue decomposition" begin
2+
@testset "1×1" begin
3+
m = @SMatrix [2.0]
4+
(vals, vecs) = eig(m)
5+
@test vals === SVector(2.0)
6+
@test vecs === SMatrix{1,1}(1.0)
7+
8+
(vals, vecs) = eig(Symmetric(m))
9+
@test vals === SVector(2.0)
10+
@test vecs === SMatrix{1,1}(1.0)
11+
end
12+
13+
@testset "2×2" for i = 1:100
14+
m_a = randn(2,2)
15+
m_a = m_a*m_a'
16+
m = SMatrix{2,2}(m_a)
17+
18+
(vals_a, vecs_a) = eig(m)
19+
(vals, vecs) = eig(m)
20+
@test vals::SVector vals_a
21+
@test (vecs*diagm(vals)*vecs')::SMatrix m
22+
23+
(vals, vecs) = eig(Symmetric(m))
24+
@test vals::SVector vals_a
25+
@test (vecs*diagm(vals)*vecs')::SMatrix m
26+
end
27+
28+
@testset "3×3" for i = 1:100
29+
m_a = randn(3,3)
30+
m_a = m_a*m_a'
31+
m = SMatrix{3,3}(m_a)
32+
33+
(vals_a, vecs_a) = eig(m)
34+
(vals, vecs) = eig(m)
35+
@test vals::SVector vals_a
36+
@test (vecs*diagm(vals)*vecs')::SMatrix m
37+
38+
(vals, vecs) = eig(Symmetric(m))
39+
@test vals::SVector vals_a
40+
@test (vecs*diagm(vals)*vecs')::SMatrix m
41+
end
42+
end

test/matrix_multiply.jl

Lines changed: 38 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
@testset "Matrix multiplication" begin
1+
@testset "Matrix multiplication" begin
22
@testset "Matrix-vector" begin
33
m = @SMatrix [1 2; 3 4]
44
v = @SVector [1, 2]
@@ -163,46 +163,46 @@
163163
@test a2::MArray{(2,2),Int,2,4} == @MArray [10 13; 22 29]
164164

165165
# Alternative builtin method used for n > 8
166-
m_array = rand(1:10, 10, 10)
167-
n_array = rand(1:10, 10, 10)
168-
a_array = m_array*n_array
166+
m_array_2 = rand(1:10, 10, 10)
167+
n_array_2 = rand(1:10, 10, 10)
168+
a_array_2 = m_array_2*n_array_2
169169

170-
m = MMatrix{10,10}(m_array)
171-
n = MMatrix{10,10}(n_array)
172-
a = MMatrix{10,10,Int}()
173-
A_mul_B!(a, m, n)
174-
@test a == a_array
170+
m_2 = MMatrix{10,10}(m_array_2)
171+
n_2 = MMatrix{10,10}(n_array_2)
172+
a_2 = MMatrix{10,10,Int}()
173+
A_mul_B!(a_2, m_2, n_2)
174+
@test a_2 == a_array_2
175175

176176
# BLAS used for n > 14
177-
m_array = randn(4, 4)
178-
n_array = randn(4, 4)
179-
a_array = m_array*n_array
180-
181-
m = MMatrix{4,4}(m_array)
182-
n = MMatrix{4,4}(n_array)
183-
a = MMatrix{4,4,Float64}()
184-
A_mul_B!(a, m, n)
185-
@test a a_array
186-
187-
m_array = randn(10, 10)
188-
n_array = randn(10, 10)
189-
a_array = m_array*n_array
190-
191-
m = MMatrix{10,10}(m_array)
192-
n = MMatrix{10,10}(n_array)
193-
a = MMatrix{10,10,Float64}()
194-
A_mul_B!(a, m, n)
195-
@test a a_array
196-
197-
m_array = rand(1:10, 16, 16)
198-
n_array = rand(1:10, 16, 16)
199-
a_array = m_array*n_array
200-
201-
m = MMatrix{16,16}(m_array)
202-
n = MMatrix{16,16}(n_array)
203-
a = MMatrix{16,16,Int}()
204-
A_mul_B!(a, m, n)
205-
@test a a_array
177+
m_array_3 = randn(4, 4)
178+
n_array_3 = randn(4, 4)
179+
a_array_3 = m_array_3*n_array_3
180+
181+
m_3 = MMatrix{4,4}(m_array_3)
182+
n_3 = MMatrix{4,4}(n_array_3)
183+
a_3 = MMatrix{4,4,Float64}()
184+
A_mul_B!(a_3, m_3, n_3)
185+
@test a_3 a_array_3
186+
187+
m_array_4 = randn(10, 10)
188+
n_array_4 = randn(10, 10)
189+
a_array_4 = m_array_4*n_array_4
190+
191+
m_4 = MMatrix{10,10}(m_array_4)
192+
n_4 = MMatrix{10,10}(n_array_4)
193+
a_4 = MMatrix{10,10,Float64}()
194+
A_mul_B!(a_4, m_4, n_4)
195+
@test a_4 a_array_4
196+
197+
m_array_5 = rand(1:10, 16, 16)
198+
n_array_5 = rand(1:10, 16, 16)
199+
a_array_5 = m_array_5*n_array_5
200+
201+
m_5 = MMatrix{16,16}(m_array_5)
202+
n_5 = MMatrix{16,16}(n_array_5)
203+
a_5 = MMatrix{16,16,Int}()
204+
A_mul_B!(a_5, m_5, n_5)
205+
@test a_5 a_array_5
206206

207207
# Float64
208208
vf = @SVector [2.0, 4.0]

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,5 +19,6 @@ using Base.Test
1919
include("linalg.jl")
2020
include("matrix_multiply.jl")
2121
include("solve.jl")
22+
include("eigen.jl")
2223
include("deque.jl")
2324
end

0 commit comments

Comments
 (0)