@@ -30,25 +30,54 @@ function Base.show(io::IO, mime::MIME{Symbol("text/plain")}, F::LU)
30
30
end
31
31
32
32
# LU decomposition
33
- function lu (A:: StaticMatrix , pivot:: Union{Val{false},Val{true}} = Val (true ))
34
- L, U, p = _lu (A, pivot)
33
+ function lu (A:: StaticMatrix , pivot:: Union{Val{false},Val{true}} = Val (true ); check = true )
34
+ L, U, p = _lu (A, pivot, check )
35
35
LU (L, U, p)
36
36
end
37
37
38
38
# For the square version, return explicit lower and upper triangular matrices.
39
39
# We would do this for the rectangular case too, but Base doesn't support that.
40
- function lu (A:: StaticMatrix{N,N} , pivot:: Union{Val{false},Val{true}} = Val (true )) where {N}
41
- L, U, p = _lu (A, pivot)
40
+ function lu (A:: StaticMatrix{N,N} , pivot:: Union{Val{false},Val{true}} = Val (true );
41
+ check = true ) where {N}
42
+ L, U, p = _lu (A, pivot, check)
42
43
LU (LowerTriangular (L), UpperTriangular (U), p)
43
44
end
44
45
45
- @generated function _lu (A:: StaticMatrix{M,N,T} , pivot) where {M,N,T}
46
+ # location of the first zero on the diagonal, 0 when not found
47
+ function _first_zero_on_diagonal (A:: StaticMatrix{M,N,T} ) where {M,N,T}
48
+ if @generated
49
+ quote
50
+ $ (map (i -> :(A[$ i, $ i] == zero (T) && return $ i), 1 : min (M, N))... )
51
+ 0
52
+ end
53
+ else
54
+ for i in 1 : min (M, N)
55
+ A[i, i] == 0 && return i
56
+ end
57
+ 0
58
+ end
59
+ end
60
+
61
+ function _first_zero_on_diagonal (A:: LinearAlgebra.AbstractTriangular{<:Any,<:StaticMatrix} )
62
+ _first_zero_on_diagonal (A. data)
63
+ end
64
+
65
+ issuccess (F:: LU ) = _first_zero_on_diagonal (F. U) == 0
66
+
67
+ @generated function _lu (A:: StaticMatrix{M,N,T} , pivot, check) where {M,N,T}
46
68
if M* N ≤ 14 * 14
47
- :(__lu (A, pivot))
69
+ quote
70
+ L, U, P = __lu (A, pivot)
71
+ if check
72
+ i = _first_zero_on_diagonal (U)
73
+ i == 0 || throw (SingularException (i))
74
+ end
75
+ L, U, P
76
+ end
48
77
else
49
78
quote
50
79
# call through to Base to avoid excessive time spent on type inference for large matrices
51
- f = lu (Matrix (A), pivot; check = false )
80
+ f = lu (Matrix (A), pivot; check = check )
52
81
# Trick to get the output eltype - can't rely on the result of f.L as
53
82
# it's not type inferrable.
54
83
T2 = arithmetic_closure (T)
0 commit comments