diff --git a/src/eigen.jl b/src/eigen.jl index e66aea1d..5e1efc4a 100644 --- a/src/eigen.jl +++ b/src/eigen.jl @@ -31,22 +31,11 @@ end @inline function _eigvals(::Size{(2,2)}, A::LinearAlgebra.RealHermSymComplexHerm{T}, permute, scale) where {T <: Real} a = A.data - - if A.uplo == 'U' - @inbounds t_half = real(a[1] + a[4])/2 - @inbounds d = real(a[1]*a[4] - a[3]'*a[3]) # Should be real - - tmp2 = t_half*t_half - d - tmp2 < 0 ? tmp = zero(tmp2) : tmp = sqrt(tmp2) # Numerically stable for identity matrices, etc. - return SVector(t_half - tmp, t_half + tmp) - else - @inbounds t_half = real(a[1] + a[4])/2 - @inbounds d = real(a[1]*a[4] - a[2]'*a[2]) # Should be real - - tmp2 = t_half*t_half - d - tmp2 < 0 ? tmp = zero(tmp2) : tmp = sqrt(tmp2) # Numerically stable for identity matrices, etc. - return SVector(t_half - tmp, t_half + tmp) - end + @inbounds t_half = (real(a[1]) + real(a[4])) / 2 + @inbounds s_half = (real(a[1]) - real(a[4])) / 2 + @inbounds tmp2 = A.uplo == 'U' ? s_half^2 + abs2(a[3]) : s_half^2 + abs2(a[2]) # expansion of (tr(A)^2 - 4*det(A)) / 4 + tmp = sqrt(tmp2) # normally, tmp > 0 if tmp2 > 0 + return SVector(t_half - tmp, t_half + tmp) end @inline function _eigvals(::Size{(3,3)}, A::LinearAlgebra.RealHermSymComplexHerm{T}, permute, scale) where {T <: Real} @@ -158,20 +147,20 @@ end @inbounds if A.uplo == 'U' if !iszero(a[3]) # A is not diagonal - t_half = real(a[1] + a[4]) / 2 - d = real(a[1] * a[4] - a[3]' * a[3]) # Should be real - - tmp2 = t_half * t_half - d - tmp = tmp2 < 0 ? zero(tmp2) : sqrt(tmp2) # Numerically stable for identity matrices, etc. + t_half = (real(a[1]) + real(a[4])) / 2 + s_half = (real(a[1]) - real(a[4])) / 2 + tmp2 = s_half^2 + abs2(a[3]) # expansion of (tr(A)^2 - 4*det(A)) / 4 + # use abs2 to ensure that it is real and > 0 (not sure about x'x being always > 0 if a[3] ≠ 0) + tmp = sqrt(tmp2) # normally, tmp > 0 if tmp2 > 0 vals = SVector(t_half - tmp, t_half + tmp) - v11 = vals[1] - a[4] - n1 = sqrt(v11' * v11 + a[3]' * a[3]) + v11 = s_half - tmp + n1 = sqrt(abs2(v11) + abs2(a[3])) v11 = v11 / n1 v12 = a[3]' / n1 - v21 = vals[2] - a[4] - n2 = sqrt(v21' * v21 + a[3]' * a[3]) + v21 = s_half + tmp + n2 = sqrt(abs2(v21) + abs2(a[3])) v21 = v21 / n2 v22 = a[3]' / n2 @@ -182,27 +171,27 @@ end end else # A.uplo == 'L' if !iszero(a[2]) # A is not diagonal - t_half = real(a[1] + a[4]) / 2 - d = real(a[1] * a[4] - a[2]' * a[2]) # Should be real - - tmp2 = t_half * t_half - d - tmp = tmp2 < 0 ? zero(tmp2) : sqrt(tmp2) # Numerically stable for identity matrices, etc. + t_half = (real(a[1]) + real(a[4])) / 2 + s_half = (real(a[1]) - real(a[4])) / 2 + tmp2 = s_half^2 + abs2(a[2]) # expansion of (tr(A)^2 - 4*det(A)) / 4 + # use abs2 to ensure that it is real and > 0 (not sure about x'x being always > 0 if a[3] ≠ 0) + tmp = sqrt(tmp2) # normally, tmp > 0 if tmp2 > 0 vals = SVector(t_half - tmp, t_half + tmp) - v11 = vals[1] - a[4] - n1 = sqrt(v11' * v11 + a[2]' * a[2]) + v11 = s_half - tmp + n1 = sqrt(abs2(v11) + abs2(a[3])) v11 = v11 / n1 v12 = a[2] / n1 - v21 = vals[2] - a[4] - n2 = sqrt(v21' * v21 + a[2]' * a[2]) + v21 = s_half + tmp + n2 = sqrt(abs2(v21) + abs2(a[3])) v21 = v21 / n2 v22 = a[2] / n2 vecs = @SMatrix [ v11 v21 ; v12 v22 ] - return Eigen(vals,vecs) + return Eigen(vals, vecs) end end