Skip to content

Commit 8ecb78a

Browse files
author
Andy Ferris
committed
Fixed 3x3 eigevalue problem with degeneracies
The 3x3 eigenvector solver wasn't working whenever there was degeneracies. Now most of the corner cases have been solved.
1 parent a430568 commit 8ecb78a

File tree

2 files changed

+118
-43
lines changed

2 files changed

+118
-43
lines changed

src/eigen.jl

Lines changed: 81 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -84,66 +84,104 @@ end
8484
end
8585

8686
# TODO fix for complex case
87-
@generated function _eig{T<:Real}(::Size{(3,3)}, A::Base.LinAlg.RealHermSymComplexHerm{T}, permute, scale)
87+
@inline function _eig{T<:Real}(::Size{(3,3)}, A::Base.LinAlg.RealHermSymComplexHerm{T}, permute, scale)
8888
S = typeof((one(T)*zero(T) + zero(T))/one(T))
8989

90-
return quote
91-
$(Expr(:meta, :inline))
90+
uplo = A.uplo
91+
data = A.data
92+
if uplo == 'U'
93+
@inbounds Afull = SMatrix{3,3}(data[1], data[4], data[7], data[4], data[5], data[8], data[7], data[8], data[9])
94+
else
95+
@inbounds Afull = SMatrix{3,3}(data[1], data[2], data[3], data[2], data[5], data[6], data[3], data[6], data[9])
96+
end
97+
98+
# Adapted from Wikipedia
99+
@inbounds p1 = Afull[4]*Afull[4] + Afull[7]*Afull[7] + Afull[8]*Afull[8]
100+
if (p1 == 0)
101+
# Afull is diagonal.
102+
@inbounds eig1 = Afull[1]
103+
@inbounds eig2 = Afull[5]
104+
@inbounds eig3 = Afull[9]
92105

93-
uplo = A.uplo
94-
data = A.data
95-
if uplo == 'U'
96-
@inbounds Afull = SMatrix{3,3}(data[1], data[4], data[7], data[4], data[5], data[8], data[7], data[8], data[9])
106+
return (SVector{3,S}(eig1, eig2, eig3), eye(SMatrix{3,3,S}))
107+
else
108+
q = trace(Afull)/3
109+
@inbounds p2 = (Afull[1] - q)^2 + (Afull[5] - q)^2 + (Afull[9] - q)^2 + 2 * p1
110+
p = sqrt(p2 / 6)
111+
B = (1 / p) * (Afull - UniformScaling(q)) # q*I
112+
r = det(B) / 2
113+
114+
# In exact arithmetic for a symmetric matrix -1 <= r <= 1
115+
# but computation error can leave it slightly outside this range.
116+
if (r <= -1)
117+
phi = S(pi) / 3
118+
elseif (r >= 1)
119+
phi = zero(S)
97120
else
98-
@inbounds Afull = SMatrix{3,3}(data[1], data[2], data[3], data[2], data[5], data[6], data[3], data[6], data[9])
121+
phi = acos(r) / 3
99122
end
100123

101-
# Adapted from Wikipedia
102-
@inbounds p1 = Afull[4]*Afull[4] + Afull[7]*Afull[7] + Afull[8]*Afull[8]
103-
if (p1 == 0)
104-
# Afull is diagonal.
105-
@inbounds eig1 = Afull[1]
106-
@inbounds eig2 = Afull[5]
107-
@inbounds eig3 = Afull[9]
108-
109-
return (SVector{3,$S}(eig1, eig2, eig3), eye(SMatrix{3,3,$S}))
110-
else
111-
q = trace(Afull)/3
112-
@inbounds p2 = (Afull[1] - q)^2 + (Afull[5] - q)^2 + (Afull[9] - q)^2 + 2 * p1
113-
p = sqrt(p2 / 6)
114-
B = (1 / p) * (Afull - UniformScaling(q)) # q*I
115-
r = det(B) / 2
116-
117-
# In exact arithmetic for a symmetric matrix -1 <= r <= 1
118-
# but computation error can leave it slightly outside this range.
119-
if (r <= -1) # TODO what type should phi be?
120-
phi = pi / 3
121-
elseif (r >= 1)
122-
phi = 0.0
124+
# the eigenvalues satisfy eig1 <= eig2 <= eig3
125+
eig3 = q + 2 * p * cos(phi)
126+
eig1 = q + 2 * p * cos(phi + (2*pi/3))
127+
eig2 = 3 * q - eig1 - eig3 # since trace(Afull) = eig1 + eig2 + eig3
128+
129+
# Now get the eigenvectors
130+
131+
# To avoid problems with double degeneracies, we tackle the most distinct
132+
# eigenvalue first
133+
if eig2 - eig1 > eig3 - eig2
134+
# The first eigenvalue is "most distinct"
135+
@inbounds tmp1 = SVector(Afull[1] - eig3, Afull[2], Afull[3])
136+
@inbounds tmp2 = SVector(Afull[4], Afull[5] - eig3, Afull[6])
137+
v3 = cross(tmp1, tmp2)
138+
n3 = vecnorm(v3)
139+
v3 = v3 / n3
140+
141+
# Find the second one from this one
142+
@inbounds tmp3 = normalize(SVector(Afull[1] - eig2, Afull[2], Afull[3]))
143+
@inbounds tmp4 = normalize(SVector(Afull[4], Afull[5] - eig2, Afull[6]))
144+
v2_1 = cross(tmp3, v3)
145+
v2_2 = cross(tmp4, v3)
146+
n2_1 = vecnorm(v2_1)
147+
n2_2 = vecnorm(v2_2)
148+
if n2_1 > n2_2
149+
v2 = v2_1 / n2_1
123150
else
124-
phi = acos(r) / 3
151+
v2 = v2_2 / n2_2
125152
end
126153

127-
# the eigenvalues satisfy eig1 <= eig2 <= eig3
128-
eig3 = q + 2 * p * cos(phi)
129-
eig1 = q + 2 * p * cos(phi + (2*pi/3))
130-
eig2 = 3 * q - eig1 - eig3 # since trace(Afull) = eig1 + eig2 + eig3
154+
# The third is easy
155+
v1 = cross(v2, v3) # should be normalized already
131156

132-
# Now get the eigenvectors
133-
# TODO branch for when eig1 == eig2?
157+
@inbounds return (SVector((eig1, eig2, eig3)), SMatrix{3,3}((v1[1], v1[2], v1[3], v2[1], v2[2], v2[3], v3[1], v3[2], v3[3])))
158+
else
159+
# The third eigenvalue is "most distinct"
134160
@inbounds tmp1 = SVector(Afull[1] - eig1, Afull[2], Afull[3])
135161
@inbounds tmp2 = SVector(Afull[4], Afull[5] - eig1, Afull[6])
136162
v1 = cross(tmp1, tmp2)
137-
v1 = v1 / vecnorm(v1)
138-
139-
@inbounds tmp1 = SVector(Afull[1] - eig2, Afull[2], Afull[3])
140-
@inbounds tmp2 = SVector(Afull[4], Afull[5] - eig2, Afull[6])
141-
v2 = cross(tmp1, tmp2)
142-
v2 = v2 / vecnorm(v2)
163+
n1 = vecnorm(v1)
164+
v1 = v1 / n1
165+
166+
# Find the second one from this one
167+
@inbounds tmp3 = normalize(SVector(Afull[1] - eig2, Afull[2], Afull[3]))
168+
@inbounds tmp4 = normalize(SVector(Afull[4], Afull[5] - eig2, Afull[6]))
169+
v2_1 = cross(tmp3, v1)
170+
v2_2 = cross(tmp4, v1)
171+
n2_1 = vecnorm(v2_1)
172+
n2_2 = vecnorm(v2_2)
173+
if n2_1 > n2_2
174+
v2 = v2_1 / n2_1
175+
else
176+
v2 = v2_2 / n2_2
177+
end
143178

179+
# The third is easy
144180
v3 = cross(v1, v2) # should be normalized already
145181

146182
@inbounds return (SVector((eig1, eig2, eig3)), SMatrix{3,3}((v1[1], v1[2], v1[3], v2[1], v2[2], v2[3], v3[1], v3[2], v3[3])))
147183
end
184+
185+
@inbounds return (SVector((eig1, eig2, eig3)), SMatrix{3,3}((v1[1], v1[2], v1[3], v2[1], v2[2], v2[3], v3[1], v3[2], v3[3])))
148186
end
149187
end

test/eigen.jl

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,4 +39,41 @@
3939
@test vals::SVector vals_a
4040
@test (vecs*diagm(vals)*vecs')::SMatrix m
4141
end
42+
43+
@testset "3x3 degenerate cases" begin
44+
# Rank 1
45+
v = randn(SVector{3,Float64})
46+
m = v*v'
47+
eigvals, eigvecs = eig(m)::Tuple{SVector,SMatrix}
48+
49+
@test isapprox(eigvecs'*eigvecs, eye(SMatrix{3,3,Float64}); atol = 1e-4) # This algorithm isn't super accurate
50+
@test eigvals SVector(0.0, 0.0, sumabs2(v))
51+
52+
# Rank 2
53+
v2 = randn(SVector{3,Float64})
54+
v2 -= dot(v,v2)*v/sumabs2(v)
55+
m += v2*v2'
56+
eigvals, eigvecs = eig(m)::Tuple{SVector,SMatrix}
57+
58+
@test isapprox(eigvecs'*eigvecs, eye(SMatrix{3,3,Float64}); atol = 1e-4)
59+
if sumabs2(v) < sumabs2(v2)
60+
@test eigvals SVector(0.0, sumabs2(v), sumabs2(v2))
61+
else
62+
@test eigvals SVector(0.0, sumabs2(v2), sumabs2(v))
63+
end
64+
65+
# Degeneracy (2 large)
66+
m = -99*(v*v')/sumabs2(v) + 100*eye(SMatrix{3,3,Float64})
67+
eigvals, eigvecs = eig(m)::Tuple{SVector,SMatrix}
68+
69+
@test isapprox(eigvecs'*eigvecs, eye(SMatrix{3,3,Float64}); atol = 1e-4)
70+
@test eigvals SVector(1.0, 100.0, 100.0)
71+
72+
# Degeneracy (2 small)
73+
m = (v*v')/sumabs2(v) + 1e-2*eye(SMatrix{3,3,Float64})
74+
eigvals, eigvecs = eig(m)::Tuple{SVector,SMatrix}
75+
76+
@test isapprox(eigvecs'*eigvecs, eye(SMatrix{3,3,Float64}); atol = 1e-4)
77+
@test eigvals SVector(1e-2, 1e-2, 1.01)
78+
end
4279
end

0 commit comments

Comments
 (0)