Skip to content

Commit 3d2d0aa

Browse files
authored
Merge pull request #886 from chriselrod/checkchol
Added check option to cholesky
2 parents 59f92e0 + 3075011 commit 3d2d0aa

File tree

2 files changed

+41
-12
lines changed

2 files changed

+41
-12
lines changed

src/cholesky.jl

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,29 @@
11
# Generic Cholesky decomposition for fixed-size matrices, mostly unrolled
22
non_hermitian_error() = throw(LinearAlgebra.PosDefException(-1))
3-
@inline function LinearAlgebra.cholesky(A::StaticMatrix)
3+
@inline function LinearAlgebra.cholesky(A::StaticMatrix; check::Bool = true)
44
ishermitian(A) || non_hermitian_error()
5-
C = _cholesky(Size(A), A)
6-
return Cholesky(C, 'U', 0)
5+
_cholesky(Size(A), A, check)
76
end
87

9-
@inline function LinearAlgebra.cholesky(A::LinearAlgebra.RealHermSymComplexHerm{<:Real, <:StaticMatrix})
10-
C = _cholesky(Size(A), A.data)
11-
return Cholesky(C, 'U', 0)
8+
@inline function LinearAlgebra.cholesky(A::LinearAlgebra.RealHermSymComplexHerm{<:Real, <:StaticMatrix}; check::Bool = true)
9+
_cholesky(Size(A), A.data, check)
1210
end
1311
@inline LinearAlgebra._chol!(A::StaticMatrix, ::Type{UpperTriangular}) = (cholesky(A).U, 0)
1412

15-
@generated function _cholesky(::Size{S}, A::StaticMatrix{M,M}) where {S,M}
13+
@inline function _chol_failure(A, info, check)
14+
if check
15+
throw(LinearAlgebra.PosDefException(info))
16+
else
17+
Cholesky(A, 'U', info)
18+
end
19+
end
20+
# x < zero(x) is check used in `sqrt`, letting LLVM eliminate that check and remove error code.
21+
@inline _nonpdcheck(x::Real) = x zero(x)
22+
@inline _nonpdcheck(x) = x == x
23+
24+
@generated function _cholesky(::Size{S}, A::StaticMatrix{M,M}, check::Bool) where {S,M}
1625
@assert (M,M) == S
17-
M > 24 && return :(_cholesky_large(Size{$S}(), A))
26+
M > 24 && return :(_cholesky_large(Size{$S}(), A, check))
1827
q = Expr(:block)
1928
for n 1:M
2029
for m n:M
@@ -28,7 +37,9 @@ end
2837
push!(q.args, :($L_m_n = muladd(-$L_m_k, $L_n_k', $L_m_n)))
2938
end
3039
L_n_n = Symbol(:L_,n,:_,n)
31-
push!(q.args, :($L_n_n = sqrt($L_n_n)))
40+
L_n_n_ltz = Symbol(:L_,n,:_,n,:_,:ltz)
41+
push!(q.args, :(_nonpdcheck($L_n_n) || return _chol_failure(A, $n, check)))
42+
push!(q.args, :($L_n_n = Base.FastMath.sqrt_fast($L_n_n)))
3243
Linv_n_n = Symbol(:Linv_,n,:_,n)
3344
push!(q.args, :($Linv_n_n = inv($L_n_n)))
3445
for m n+1:M
@@ -46,13 +57,15 @@ end
4657
push!(ret.args, :(zero(T)))
4758
end
4859
end
49-
push!(q.args, :(similar_type(A, T)($ret)))
60+
push!(q.args, :(Cholesky(similar_type(A, T)($ret), 'U', 0)))
5061
return Expr(:block, Expr(:meta, :inline), q)
5162
end
5263

5364
# Otherwise default algorithm returning wrapped SizedArray
54-
@inline _cholesky_large(::Size{S}, A::StaticArray) where {S} =
55-
similar_type(A)(cholesky(Hermitian(Matrix(A))).U)
65+
@inline function _cholesky_large(::Size{S}, A::StaticArray, check::Bool) where {S}
66+
C = cholesky(Hermitian(Matrix(A)); check=check)
67+
Cholesky(similar_type(A)(C.U), 'U', C.info)
68+
end
5669

5770
LinearAlgebra.hermitian_type(::Type{SA}) where {T, S, SA<:SArray{S,T}} = Hermitian{T,SA}
5871

test/chol.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ using LinearAlgebra: PosDefException
3434
m_a = randn(elty, 4,4)
3535
#non hermitian
3636
@test_throws PosDefException cholesky(SMatrix{4,4}(m_a))
37+
3738
m_a = m_a*m_a'
3839
m = SMatrix{4,4}(m_a)
3940
@test cholesky(m).L cholesky(m_a).L
@@ -86,6 +87,21 @@ using LinearAlgebra: PosDefException
8687
@test (@inferred c \ v) isa SVector{3,elty}
8788
@test c \ v c_a \ v_a
8889
end
90+
91+
@testset "Check" begin
92+
for i [1,3,7,25]
93+
x = SVector(ntuple(elty, i))
94+
nonpd = x * x'
95+
if i > 1
96+
@test_throws PosDefException cholesky(nonpd)
97+
@test !issuccess(cholesky(nonpd,check=false))
98+
@test_throws PosDefException cholesky(Hermitian(nonpd))
99+
@test !issuccess(cholesky(Hermitian(nonpd),check=false))
100+
else
101+
@test issuccess(cholesky(Hermitian(nonpd),check=false))
102+
end
103+
end
104+
end
89105
end
90106

91107
@testset "static blockmatrix" for i = 1:10

0 commit comments

Comments
 (0)