Skip to content

Commit 261939b

Browse files
author
Andy Ferris
committed
Added eigvals, tests
1 parent 13fc61f commit 261939b

File tree

3 files changed

+164
-26
lines changed

3 files changed

+164
-26
lines changed

src/StaticArrays.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@ import Base: @pure, @propagate_inbounds, getindex, setindex!, size, similar,
66
length, convert, promote_op, map, map!, reduce, mapreduce,
77
broadcast, broadcast!, conj, transpose, ctranspose, hcat, vcat,
88
ones, zeros, eye, one, cross, vecdot, reshape, fill, fill!, det,
9-
inv, eig, trace, vecnorm, norm, dot, diagm, sum, prod, count, any,
10-
all, sumabs, sumabs2, minimum, maximum, extrema, mean, copy
9+
inv, eig, eigvals, trace, vecnorm, norm, dot, diagm, sum, prod,
10+
count, any, all, sumabs, sumabs2, minimum, maximum, extrema, mean,
11+
copy
1112

1213
export StaticScalar, StaticArray, StaticVector, StaticMatrix
1314
export Scalar, SArray, SVector, SMatrix

src/eigen.jl

Lines changed: 127 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,107 @@
1+
2+
@inline eigvals{T<:Real,SA<:StaticArray}(a::Base.LinAlg.RealHermSymComplexHerm{T,SA},; permute::Bool=true, scale::Bool=true) = _eigvals(Size(SA), a, permute, scale)
3+
@inline function eigvals(a::StaticArray; permute::Bool=true, scale::Bool=true)
4+
if ishermitian(a)
5+
_eigvals(Size(a), Hermitian(a), permute, scale)
6+
else
7+
error("Only hermitian matrices are diagonalizable by *StaticArrays*. Non-Hermitian matrices should be converted to `Array` first.")
8+
end
9+
end
10+
11+
@inline _eigvals(::Size{(1,1)}, a, permute, scale) = @inbounds return SVector(real(a.data[1]))
12+
13+
@inline function _eigvals{T<:Real}(::Size{(2,2)}, A::Base.LinAlg.RealHermSymComplexHerm{T}, permute, scale)
14+
a = A.data
15+
16+
if A.uplo == 'U'
17+
@inbounds t_half = real(a[1] + a[4])/2
18+
@inbounds d = real(a[1]*a[4] - a[3]'*a[3]) # Should be real
19+
20+
tmp2 = t_half*t_half - d
21+
tmp2 < 0 ? tmp = zero(tmp2) : tmp = sqrt(tmp2) # Numerically stable for identity matrices, etc.
22+
return SVector(t_half - tmp, t_half + tmp)
23+
else
24+
@inbounds t_half = real(a[1] + a[4])/2
25+
@inbounds d = real(a[1]*a[4] - a[2]'*a[2]) # Should be real
26+
27+
tmp2 = t_half*t_half - d
28+
tmp2 < 0 ? tmp = zero(tmp2) : tmp = sqrt(tmp2) # Numerically stable for identity matrices, etc.
29+
return SVector(t_half - tmp, t_half + tmp)
30+
end
31+
end
32+
33+
@inline function _eigvals{T<:Real}(::Size{(3,3)}, A::Base.LinAlg.RealHermSymComplexHerm{T}, permute, scale)
34+
S = typeof((one(T)*zero(T) + zero(T))/one(T))
35+
Sreal = real(S)
36+
37+
@inbounds a11 = convert(Sreal, A.data[1])
38+
@inbounds a22 = convert(Sreal, A.data[5])
39+
@inbounds a33 = convert(Sreal, A.data[9])
40+
if A.uplo == 'U'
41+
@inbounds a12 = convert(S, A.data[4])
42+
@inbounds a13 = convert(S, A.data[7])
43+
@inbounds a23 = convert(S, A.data[8])
44+
else
45+
@inbounds a12 = conj(convert(S, A.data[2]))
46+
@inbounds a13 = conj(convert(S, A.data[3]))
47+
@inbounds a23 = conj(convert(S, A.data[6]))
48+
end
49+
50+
p1 = abs2(a12) + abs2(a13) + abs2(a23)
51+
if (p1 == 0)
52+
# Matrix is diagonal
53+
if a11 < a22
54+
if a22 < a33
55+
return SVector(a11, a22, a33)
56+
elseif a33 < a11
57+
return SVector(a33, a11, a22)
58+
else
59+
return SVector(a11, a33, a22)
60+
end
61+
else #a22 < a11
62+
if a11 < a33
63+
return SVector(a22, a11, a33)
64+
elseif a33 < a22
65+
return SVector(a33, a22, a11)
66+
else
67+
return SVector(a22, a33, a11)
68+
end
69+
end
70+
end
71+
72+
q = (a11 + a22 + a33) / 3
73+
p2 = abs2(a11 - q) + abs2(a22 - q) + abs2(a33 - q) + 2 * p1
74+
p = sqrt(p2 / 6)
75+
invp = inv(p)
76+
b11 = (a11 - q) * invp
77+
b22 = (a22 - q) * invp
78+
b33 = (a33 - q) * invp
79+
b12 = a12 * invp
80+
b13 = a13 * invp
81+
b23 = a23 * invp
82+
B = SMatrix{3,3,S}((b11, conj(b12), conj(b13), b12, b22, conj(b23), b13, b23, b33))
83+
r = real(det(B)) / 2
84+
85+
# In exact arithmetic for a symmetric matrix -1 <= r <= 1
86+
# but computation error can leave it slightly outside this range.
87+
if (r <= -1)
88+
phi = Sreal(pi) / 3
89+
elseif (r >= 1)
90+
phi = zero(Sreal)
91+
else
92+
phi = acos(r) / 3
93+
end
94+
95+
eig3 = q + 2 * p * cos(phi)
96+
eig1 = q + 2 * p * cos(phi + (2*Sreal(pi)/3))
97+
eig2 = 3 * q - eig1 - eig3 # since trace(A) = eig1 + eig2 + eig3
98+
99+
return SVector(eig1, eig2, eig3)
100+
end
101+
102+
103+
104+
1105
@inline function eig(A::StaticMatrix; permute::Bool=true, scale::Bool=true)
2106
_eig(Size(A), A, permute, scale)
3107
end
@@ -106,8 +210,27 @@ end
106210
p1 = abs2(a12) + abs2(a13) + abs2(a23)
107211
if (p1 == 0)
108212
# Matrix is diagonal
109-
# TODO need to sort the eigenvalues
110-
return (SVector(a11, a22, a33), eye(SMatrix{3,3,S}))
213+
v1 = SVector(one(S), zero(S), zero(S))
214+
v2 = SVector(zero(S), one(S), zero(S))
215+
v3 = SVector(zero(S), zero(S), one(S) )
216+
217+
if a11 < a22
218+
if a22 < a33
219+
return (SVector(a11, a22, a33), hcat(v1,v2,v3))
220+
elseif a33 < a11
221+
return (SVector(a33, a11, a22), hcat(v3,v1,v2))
222+
else
223+
return (SVector(a11, a33, a22), hcat(v1,v3,v2))
224+
end
225+
else #a22 < a11
226+
if a11 < a33
227+
return (SVector(a22, a11, a33), hcat(v2,v1,v3))
228+
elseif a33 < a22
229+
return (SVector(a33, a22, a11), hcat(v3,v2,v1))
230+
else
231+
return (SVector(a22, a33, a11), hcat(v2,v3,v1))
232+
end
233+
end
111234
end
112235

113236
q = (a11 + a22 + a33) / 3
@@ -251,6 +374,8 @@ end
251374
#=
252375
Boost Software License - Version 1.0 - August 17th, 2003
253376
377+
378+
254379
Permission is hereby granted, free of charge, to any person or organization
255380
obtaining a copy of the software and accompanying documentation covered by
256381
this license (the "Software") to use, reproduce, display, distribute,

test/eigen.jl

Lines changed: 34 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
m = @SMatrix [2.0]
44
(vals, vecs) = eig(m)
55
@test vals === SVector(2.0)
6+
@test eigvals(m) === vals
67
@test vecs === SMatrix{1,1}(1.0)
78

89
(vals, vecs) = eig(Symmetric(m))
@@ -18,10 +19,12 @@
1819
(vals_a, vecs_a) = eig(m)
1920
(vals, vecs) = eig(m)
2021
@test vals::SVector vals_a
22+
@test eigvals(m) vals
2123
@test (vecs*diagm(vals)*vecs')::SMatrix m
2224

2325
(vals, vecs) = eig(Symmetric(m))
2426
@test vals::SVector vals_a
27+
@test eigvals(m) vals
2528
@test (vecs*diagm(vals)*vecs')::SMatrix m
2629
end
2730

@@ -33,72 +36,81 @@
3336
(vals_a, vecs_a) = eig(m)
3437
(vals, vecs) = eig(m)
3538
@test vals::SVector vals_a
39+
@test eigvals(m) vals
3640
@test (vecs*diagm(vals)*vecs')::SMatrix m
3741

3842
(vals, vecs) = eig(Symmetric(m))
3943
@test vals::SVector vals_a
44+
@test eigvals(m) vals
4045
@test (vecs*diagm(vals)*vecs')::SMatrix m
4146
end
4247

4348
@testset "3x3 degenerate cases" begin
4449
# Rank 1
4550
v = randn(SVector{3,Float64})
4651
m = v*v'
47-
eigvals, eigvecs = eig(m)::Tuple{SVector,SMatrix}
52+
vals, vecs = eig(m)::Tuple{SVector,SMatrix}
4853

49-
@test isapprox(eigvecs'*eigvecs, eye(SMatrix{3,3,Float64}); atol = 1e-4) # This algorithm isn't super accurate
50-
@test eigvals SVector(0.0, 0.0, sumabs2(v))
54+
@test vecs'*vecs eye(SMatrix{3,3,Float64})
55+
@test vals SVector(0.0, 0.0, sumabs2(v))
56+
@test eigvals(m) vals
5157

5258
# Rank 2
5359
v2 = randn(SVector{3,Float64})
5460
v2 -= dot(v,v2)*v/sumabs2(v)
5561
m += v2*v2'
56-
eigvals, eigvecs = eig(m)::Tuple{SVector,SMatrix}
62+
vals, vecs = eig(m)::Tuple{SVector,SMatrix}
5763

58-
@test isapprox(eigvecs'*eigvecs, eye(SMatrix{3,3,Float64}); atol = 1e-4)
64+
@test vecs'*vecs eye(SMatrix{3,3,Float64})
5965
if sumabs2(v) < sumabs2(v2)
60-
@test eigvals SVector(0.0, sumabs2(v), sumabs2(v2))
66+
@test vals SVector(0.0, sumabs2(v), sumabs2(v2))
6167
else
62-
@test eigvals SVector(0.0, sumabs2(v2), sumabs2(v))
68+
@test vals SVector(0.0, sumabs2(v2), sumabs2(v))
6369
end
70+
@test eigvals(m) vals
6471

6572
# Degeneracy (2 large)
6673
m = -99*(v*v')/sumabs2(v) + 100*eye(SMatrix{3,3,Float64})
67-
eigvals, eigvecs = eig(m)::Tuple{SVector,SMatrix}
74+
vals, vecs = eig(m)::Tuple{SVector,SMatrix}
6875

69-
@test isapprox(eigvecs'*eigvecs, eye(SMatrix{3,3,Float64}); atol = 1e-4)
70-
@test eigvals SVector(1.0, 100.0, 100.0)
76+
@test vecs'*vecs eye(SMatrix{3,3,Float64})
77+
@test vals SVector(1.0, 100.0, 100.0)
78+
@test eigvals(m) vals
7179

7280
# Degeneracy (2 small)
7381
m = (v*v')/sumabs2(v) + 1e-2*eye(SMatrix{3,3,Float64})
74-
eigvals, eigvecs = eig(m)::Tuple{SVector,SMatrix}
82+
vals, vecs = eig(m)::Tuple{SVector,SMatrix}
7583

76-
@test isapprox(eigvecs'*eigvecs, eye(SMatrix{3,3,Float64}); atol = 1e-4)
77-
@test eigvals SVector(1e-2, 1e-2, 1.01)
84+
@test vecs'*vecs eye(SMatrix{3,3,Float64})
85+
@test vals SVector(1e-2, 1e-2, 1.01)
86+
@test eigvals(m) vals
7887

7988
# Block diagonal
8089
m = @SMatrix [1.0 0.0 0.0;
8190
0.0 1.0 1.0;
8291
0.0 1.0 1.0]
83-
eigvals, eigvecs = eig(m)::Tuple{SVector,SMatrix}
92+
vals, vecs = eig(m)::Tuple{SVector,SMatrix}
8493

85-
@test eigvals [0.0, 1.0, 2.0]
86-
@test eigvecs*diagm(eigvals)*eigvecs' m
94+
@test vals [0.0, 1.0, 2.0]
95+
@test vecs*diagm(vals)*vecs' m
96+
@test eigvals(m) vals
8797

8898
m = @SMatrix [1.0 0.0 1.0;
8999
0.0 1.0 0.0;
90100
1.0 0.0 1.0]
91-
eigvals, eigvecs = eig(m)::Tuple{SVector,SMatrix}
101+
vals, vecs = eig(m)::Tuple{SVector,SMatrix}
92102

93-
@test eigvals [0.0, 1.0, 2.0]
94-
@test eigvecs*diagm(eigvals)*eigvecs' m
103+
@test vals [0.0, 1.0, 2.0]
104+
@test vecs*diagm(vals)*vecs' m
105+
@test eigvals(m) vals
95106

96107
m = @SMatrix [1.0 1.0 0.0;
97108
1.0 1.0 0.0;
98109
0.0 0.0 1.0]
99-
eigvals, eigvecs = eig(m)::Tuple{SVector,SMatrix}
110+
vals, vecs = eig(m)::Tuple{SVector,SMatrix}
100111

101-
@test eigvals [0.0, 1.0, 2.0]
102-
@test eigvecs*diagm(eigvals)*eigvecs' m
112+
@test vals [0.0, 1.0, 2.0]
113+
@test vecs*diagm(vals)*vecs' m
114+
@test eigvals(m) vals
103115
end
104116
end

0 commit comments

Comments
 (0)