|
1 | 1 | # Generic Cholesky decomposition for fixed-size matrices, mostly unrolled
|
2 | 2 | non_hermitian_error() = throw(LinearAlgebra.PosDefException(-1))
|
3 |
| -@inline function LinearAlgebra.cholesky(A::StaticMatrix) |
| 3 | +@inline function LinearAlgebra.cholesky(A::StaticMatrix; check::Bool = true) |
4 | 4 | ishermitian(A) || non_hermitian_error()
|
5 |
| - C = _cholesky(Size(A), A) |
6 |
| - return Cholesky(C, 'U', 0) |
| 5 | + _cholesky(Size(A), A, check) |
7 | 6 | end
|
8 | 7 |
|
9 |
| -@inline function LinearAlgebra.cholesky(A::LinearAlgebra.RealHermSymComplexHerm{<:Real, <:StaticMatrix}) |
10 |
| - C = _cholesky(Size(A), A.data) |
11 |
| - return Cholesky(C, 'U', 0) |
| 8 | +@inline function LinearAlgebra.cholesky(A::LinearAlgebra.RealHermSymComplexHerm{<:Real, <:StaticMatrix}; check::Bool = true) |
| 9 | + _cholesky(Size(A), A.data, check) |
12 | 10 | end
|
13 | 11 | @inline LinearAlgebra._chol!(A::StaticMatrix, ::Type{UpperTriangular}) = (cholesky(A).U, 0)
|
14 | 12 |
|
15 |
| -@generated function _cholesky(::Size{S}, A::StaticMatrix{M,M}) where {S,M} |
| 13 | +@inline function _chol_failure(A, info, check) |
| 14 | + if check |
| 15 | + throw(LinearAlgebra.PosDefException(info)) |
| 16 | + else |
| 17 | + Cholesky(A, 'U', info) |
| 18 | + end |
| 19 | +end |
| 20 | +# x < zero(x) is check used in `sqrt`, letting LLVM eliminate that check and remove error code. |
| 21 | +@inline _nonpdcheck(x::Real) = x ≥ zero(x) |
| 22 | +@inline _nonpdcheck(x) = x == x |
| 23 | + |
| 24 | +@generated function _cholesky(::Size{S}, A::StaticMatrix{M,M}, check::Bool) where {S,M} |
16 | 25 | @assert (M,M) == S
|
17 |
| - M > 24 && return :(_cholesky_large(Size{$S}(), A)) |
| 26 | + M > 24 && return :(_cholesky_large(Size{$S}(), A, check)) |
18 | 27 | q = Expr(:block)
|
19 | 28 | for n ∈ 1:M
|
20 | 29 | for m ∈ n:M
|
|
28 | 37 | push!(q.args, :($L_m_n = muladd(-$L_m_k, $L_n_k', $L_m_n)))
|
29 | 38 | end
|
30 | 39 | L_n_n = Symbol(:L_,n,:_,n)
|
31 |
| - push!(q.args, :($L_n_n = sqrt($L_n_n))) |
| 40 | + L_n_n_ltz = Symbol(:L_,n,:_,n,:_,:ltz) |
| 41 | + push!(q.args, :(_nonpdcheck($L_n_n) || return _chol_failure(A, $n, check))) |
| 42 | + push!(q.args, :($L_n_n = Base.FastMath.sqrt_fast($L_n_n))) |
32 | 43 | Linv_n_n = Symbol(:Linv_,n,:_,n)
|
33 | 44 | push!(q.args, :($Linv_n_n = inv($L_n_n)))
|
34 | 45 | for m ∈ n+1:M
|
|
46 | 57 | push!(ret.args, :(zero(T)))
|
47 | 58 | end
|
48 | 59 | end
|
49 |
| - push!(q.args, :(similar_type(A, T)($ret))) |
| 60 | + push!(q.args, :(Cholesky(similar_type(A, T)($ret), 'U', 0))) |
50 | 61 | return Expr(:block, Expr(:meta, :inline), q)
|
51 | 62 | end
|
52 | 63 |
|
53 | 64 | # Otherwise default algorithm returning wrapped SizedArray
|
54 |
| -@inline _cholesky_large(::Size{S}, A::StaticArray) where {S} = |
55 |
| - similar_type(A)(cholesky(Hermitian(Matrix(A))).U) |
| 65 | +@inline function _cholesky_large(::Size{S}, A::StaticArray, check::Bool) where {S} |
| 66 | + C = cholesky(Hermitian(Matrix(A)); check=check) |
| 67 | + Cholesky(similar_type(A)(C.U), 'U', C.info) |
| 68 | +end |
56 | 69 |
|
57 | 70 | LinearAlgebra.hermitian_type(::Type{SA}) where {T, S, SA<:SArray{S,T}} = Hermitian{T,SA}
|
58 | 71 |
|
|
0 commit comments