Skip to content

Commit 413ba1b

Browse files
committed
Added check option to cholesky
1 parent 59f92e0 commit 413ba1b

File tree

2 files changed

+39
-11
lines changed

2 files changed

+39
-11
lines changed

src/cholesky.jl

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,33 @@
11
# Generic Cholesky decomposition for fixed-size matrices, mostly unrolled
22
non_hermitian_error() = throw(LinearAlgebra.PosDefException(-1))
3-
@inline function LinearAlgebra.cholesky(A::StaticMatrix)
3+
@inline function LinearAlgebra.cholesky(A::StaticMatrix; check::Bool = true)
44
ishermitian(A) || non_hermitian_error()
5-
C = _cholesky(Size(A), A)
6-
return Cholesky(C, 'U', 0)
5+
_cholesky(Size(A), A, check)
6+
# (check && (info ≠ 0)) && throw(LinearAlgebra.PosDefException(info))
7+
# return Cholesky(C, 'U', info)
78
end
89

9-
@inline function LinearAlgebra.cholesky(A::LinearAlgebra.RealHermSymComplexHerm{<:Real, <:StaticMatrix})
10-
C = _cholesky(Size(A), A.data)
11-
return Cholesky(C, 'U', 0)
10+
@inline function LinearAlgebra.cholesky(A::LinearAlgebra.RealHermSymComplexHerm{<:Real, <:StaticMatrix}; check::Bool = true)
11+
C = _cholesky(Size(A), A.data, check)
12+
# (check && (info ≠ 0)) && throw(LinearAlgebra.PosDefException(info))
13+
# return Cholesky(C, 'U', 0)
1214
end
1315
@inline LinearAlgebra._chol!(A::StaticMatrix, ::Type{UpperTriangular}) = (cholesky(A).U, 0)
1416

15-
@generated function _cholesky(::Size{S}, A::StaticMatrix{M,M}) where {S,M}
17+
@inline function _check_chol(A, info, check)
18+
if check
19+
throw(LinearAlgebra.PosDefException(info))
20+
else
21+
return Cholesky(A, 'U', info)
22+
end
23+
end
24+
@inline _nonpdcheck(x::Real) = x < zero(x)
25+
@inline _nonpdcheck(x) = false
26+
27+
@generated function _cholesky(::Size{S}, A::StaticMatrix{M,M}, check::Bool) where {S,M}
1628
@assert (M,M) == S
1729
M > 24 && return :(_cholesky_large(Size{$S}(), A))
18-
q = Expr(:block)
30+
q = Expr(:block, :(info = 0), :(failure = false))
1931
for n 1:M
2032
for m n:M
2133
L_m_n = Symbol(:L_,m,:_,n)
@@ -28,7 +40,13 @@ end
2840
push!(q.args, :($L_m_n = muladd(-$L_m_k, $L_n_k', $L_m_n)))
2941
end
3042
L_n_n = Symbol(:L_,n,:_,n)
31-
push!(q.args, :($L_n_n = sqrt($L_n_n)))
43+
L_n_n_ltz = Symbol(:L_,n,:_,n,:_,:ltz)
44+
# x < 0.0 is check used in `sqrt`, letting LLVM eliminate that check and remove error code.
45+
# push!(q.args, :($L_n_n_ltz = )
46+
push!(q.args, :($L_n_n = _nonpdcheck($L_n_n) ? (return _check_chol(A, $n, check)) : sqrt($L_n_n)))
47+
# push!(q.args, :(info = ($L_n_n_ltz & (!failure)) ? $n : info))
48+
# push!(q.args, :(failure |= $L_n_n_ltz))
49+
# push!(q.args, :($L_n_n = $L_n_n_ltz ? float(typeof($L_n_n))(NaN) : sqrt($L_n_n)))
3250
Linv_n_n = Symbol(:Linv_,n,:_,n)
3351
push!(q.args, :($Linv_n_n = inv($L_n_n)))
3452
for m n+1:M
@@ -46,13 +64,13 @@ end
4664
push!(ret.args, :(zero(T)))
4765
end
4866
end
49-
push!(q.args, :(similar_type(A, T)($ret)))
67+
push!(q.args, :(Cholesky(similar_type(A, T)($ret), 'U', 0)))
5068
return Expr(:block, Expr(:meta, :inline), q)
5169
end
5270

5371
# Otherwise default algorithm returning wrapped SizedArray
5472
@inline _cholesky_large(::Size{S}, A::StaticArray) where {S} =
55-
similar_type(A)(cholesky(Hermitian(Matrix(A))).U)
73+
Cholesky(similar_type(A)(cholesky(Hermitian(Matrix(A))).U), 'U', 0)
5674

5775
LinearAlgebra.hermitian_type(::Type{SA}) where {T, S, SA<:SArray{S,T}} = Hermitian{T,SA}
5876

test/chol.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,16 @@ using LinearAlgebra: PosDefException
3434
m_a = randn(elty, 4,4)
3535
#non hermitian
3636
@test_throws PosDefException cholesky(SMatrix{4,4}(m_a))
37+
nonpd = SVector{4}(@view(m_a[:,1])) |> x -> x * x'
38+
if elty <: Real
39+
@test_throws PosDefException cholesky(nonpd)
40+
@test !issuccess(cholesky(nonpd,check=false))
41+
@test_throws PosDefException cholesky(Hermitian(nonpd))
42+
@test !issuccess(cholesky(Symmetric(nonpd),check=false))
43+
else
44+
@test issuccess(cholesky(Hermitian(nonpd),check=false))
45+
end
46+
3747
m_a = m_a*m_a'
3848
m = SMatrix{4,4}(m_a)
3949
@test cholesky(m).L cholesky(m_a).L

0 commit comments

Comments
 (0)