Skip to content

Commit 53b4a0e

Browse files
authored
Add singularity checks for lu, method for issuccess. (#766)
This makes the implementation of `lu` conform to the API of LinearAlgebra: check by default, optionally skip check. Tests and methods that use `lu` are modified accordingly.
1 parent 93ec354 commit 53b4a0e

File tree

4 files changed

+47
-12
lines changed

4 files changed

+47
-12
lines changed

src/StaticArrays.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ import Base: sqrt, exp, log
1717
using LinearAlgebra
1818
import LinearAlgebra: transpose, adjoint, dot, eigvals, eigen, lyap, tr,
1919
kron, diag, norm, dot, diagm, lu, svd, svdvals,
20-
factorize, ishermitian, issymmetric, isposdef, normalize,
20+
factorize, ishermitian, issymmetric, isposdef, issuccess, normalize,
2121
normalize!, Eigen, det, logdet, cross, diff, qr, \
2222
using LinearAlgebra: checksquare
2323

src/det.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@ end
4747
if prod(S) 14*14
4848
quote
4949
@_inline_meta
50-
LUp = lu(A)
51-
det(LUp.U)*_parity(LUp.p)
50+
_, U, p = _lu(A, Val(true), false)
51+
det(UpperTriangular(U))*_parity(p)
5252
end
5353
else
5454
:(@_inline_meta; det(Matrix(A)))
@@ -62,7 +62,7 @@ end
6262
if prod(S) 14*14
6363
quote
6464
@_inline_meta
65-
LUp = lu(A)
65+
LUp = lu(A; check = false)
6666
d, s = logabsdet(LUp.U)
6767
d + log(s*_parity(LUp.p))
6868
end

src/lu.jl

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,25 +30,54 @@ function Base.show(io::IO, mime::MIME{Symbol("text/plain")}, F::LU)
3030
end
3131

3232
# LU decomposition
33-
function lu(A::StaticMatrix, pivot::Union{Val{false},Val{true}}=Val(true))
34-
L, U, p = _lu(A, pivot)
33+
function lu(A::StaticMatrix, pivot::Union{Val{false},Val{true}}=Val(true); check = true)
34+
L, U, p = _lu(A, pivot, check)
3535
LU(L, U, p)
3636
end
3737

3838
# For the square version, return explicit lower and upper triangular matrices.
3939
# We would do this for the rectangular case too, but Base doesn't support that.
40-
function lu(A::StaticMatrix{N,N}, pivot::Union{Val{false},Val{true}}=Val(true)) where {N}
41-
L, U, p = _lu(A, pivot)
40+
function lu(A::StaticMatrix{N,N}, pivot::Union{Val{false},Val{true}}=Val(true);
41+
check = true) where {N}
42+
L, U, p = _lu(A, pivot, check)
4243
LU(LowerTriangular(L), UpperTriangular(U), p)
4344
end
4445

45-
@generated function _lu(A::StaticMatrix{M,N,T}, pivot) where {M,N,T}
46+
# location of the first zero on the diagonal, 0 when not found
47+
function _first_zero_on_diagonal(A::StaticMatrix{M,N,T}) where {M,N,T}
48+
if @generated
49+
quote
50+
$(map(i -> :(A[$i, $i] == zero(T) && return $i), 1:min(M, N))...)
51+
0
52+
end
53+
else
54+
for i in 1:min(M, N)
55+
A[i, i] == 0 && return i
56+
end
57+
0
58+
end
59+
end
60+
61+
function _first_zero_on_diagonal(A::LinearAlgebra.AbstractTriangular{<:Any,<:StaticMatrix})
62+
_first_zero_on_diagonal(A.data)
63+
end
64+
65+
issuccess(F::LU) = _first_zero_on_diagonal(F.U) == 0
66+
67+
@generated function _lu(A::StaticMatrix{M,N,T}, pivot, check) where {M,N,T}
4668
if M*N 14*14
47-
:(__lu(A, pivot))
69+
quote
70+
L, U, P = __lu(A, pivot)
71+
if check
72+
i = _first_zero_on_diagonal(U)
73+
i == 0 || throw(SingularException(i))
74+
end
75+
L, U, P
76+
end
4877
else
4978
quote
5079
# call through to Base to avoid excessive time spent on type inference for large matrices
51-
f = lu(Matrix(A), pivot; check = false)
80+
f = lu(Matrix(A), pivot; check = check)
5281
# Trick to get the output eltype - can't rely on the result of f.L as
5382
# it's not type inferrable.
5483
T2 = arithmetic_closure(T)

test/lu.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ end
1111

1212
@testset "LU decomposition ($m×$n, pivot=$pivot)" for pivot in (true, false), m in [0:4..., 15], n in [0:4..., 15]
1313
a = SMatrix{m,n,Int}(1:(m*n))
14-
l, u, p = @inferred(lu(a, Val{pivot}()))
14+
l, u, p = @inferred(lu(a, Val{pivot}(); check = false))
1515

1616
# expected types
1717
@test p isa SVector{m,Int}
@@ -57,5 +57,11 @@ end
5757
# test if / and \ work with lu:
5858
@test a\b_col a_lu\b_col
5959
@test b_line/a b_line/a_lu
60+
end
6061

62+
@testset "LU singularity check" for m in [2, 3, 20], n in [2, 3, 20]
63+
# NOTE: large dimensions test fallback to LinearAlgebra.lu
64+
A = ones(SMatrix{m,n})
65+
@test_throws SingularException lu(A)
66+
@test !issuccess(lu(A; check = false))
6167
end

0 commit comments

Comments
 (0)