Skip to content

Commit 5dfbe09

Browse files
authored
LU decomposition for symmetric and Hermitian static matrices (#972)
* LU decomposition for symmetric and Hermitian static matrices * bump version * fix for recent LU changes * more LU to test group B
1 parent bb8577a commit 5dfbe09

File tree

4 files changed

+26
-13
lines changed

4 files changed

+26
-13
lines changed

src/lu.jl

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ function Base.show(io::IO, mime::MIME{Symbol("text/plain")}, F::LU)
2929
show(io, mime, F.U)
3030
end
3131

32+
const StaticLUMatrix{N,M,T} = Union{StaticMatrix{N,M,T}, Symmetric{T,<:StaticMatrix{N,M,T}}, Hermitian{T,<:StaticMatrix{N,M,T}}}
33+
3234
# LU decomposition
3335
pivot_options = if isdefined(LinearAlgebra, :PivotingStrategy) # introduced in Julia v1.7
3436
(:(Val{true}), :(Val{false}), :NoPivot, :RowMaximum)
@@ -37,22 +39,22 @@ else
3739
end
3840
for pv in pivot_options
3941
# ... define each `pivot::Val{true/false}` method individually to avoid ambiguties
40-
@eval function lu(A::StaticMatrix, pivot::$pv; check = true)
42+
@eval function lu(A::StaticLUMatrix, pivot::$pv; check = true)
4143
L, U, p = _lu(A, pivot, check)
4244
LU(L, U, p)
4345
end
4446

4547
# For the square version, return explicit lower and upper triangular matrices.
4648
# We would do this for the rectangular case too, but Base doesn't support that.
47-
@eval function lu(A::StaticMatrix{N,N}, pivot::$pv; check = true) where {N}
49+
@eval function lu(A::StaticLUMatrix{N,N}, pivot::$pv; check = true) where {N}
4850
L, U, p = _lu(A, pivot, check)
4951
LU(LowerTriangular(L), UpperTriangular(U), p)
5052
end
5153
end
52-
lu(A::StaticMatrix; check = true) = lu(A, Val(true); check=check)
54+
lu(A::StaticLUMatrix; check = true) = lu(A, Val(true); check=check)
5355

5456
# location of the first zero on the diagonal, 0 when not found
55-
function _first_zero_on_diagonal(A::StaticMatrix{M,N,T}) where {M,N,T}
57+
function _first_zero_on_diagonal(A::StaticLUMatrix{M,N,T}) where {M,N,T}
5658
if @generated
5759
quote
5860
$(map(i -> :(A[$i, $i] == zero(T) && return $i), 1:min(M, N))...)
@@ -72,7 +74,7 @@ end
7274

7375
issuccess(F::LU) = _first_zero_on_diagonal(F.U) == 0
7476

75-
@generated function _lu(A::StaticMatrix{M,N,T}, pivot, check) where {M,N,T}
77+
@generated function _lu(A::StaticLUMatrix{M,N,T}, pivot, check) where {M,N,T}
7678
if M*N 14*14
7779
_pivot = if isdefined(LinearAlgebra, :PivotingStrategy) # v1.7 feature
7880
pivot === RowMaximum ? Val(true) : pivot === NoPivot ? Val(false) : pivot()
@@ -125,6 +127,9 @@ __lu(A::StaticMatrix{M,0,T}, ::Val{Pivot}) where {T,M,Pivot} =
125127
__lu(A::StaticMatrix{1,1,T}, ::Val{Pivot}) where {T,Pivot} =
126128
(SMatrix{1,1}(one(T)), A, SVector(1))
127129

130+
__lu(A::LinearAlgebra.HermOrSym{T,<:StaticMatrix{1,1,T}}, ::Val{Pivot}) where {T,Pivot} =
131+
(SMatrix{1,1}(one(T)), A.data, SVector(1))
132+
128133
__lu(A::StaticMatrix{1,N,T}, ::Val{Pivot}) where {N,T,Pivot} =
129134
(SMatrix{1,1,T}(one(T)), A, SVector{1,Int}(1))
130135

@@ -158,7 +163,7 @@ function __lu(A::StaticMatrix{M,1}, ::Val{Pivot}) where {M,Pivot}
158163
return (SMatrix{M,1}(L), U, p)
159164
end
160165

161-
function __lu(A::StaticMatrix{M,N,T}, ::Val{Pivot}) where {M,N,T,Pivot}
166+
function __lu(A::StaticLUMatrix{M,N,T}, ::Val{Pivot}) where {M,N,T,Pivot}
162167
@inbounds begin
163168
kp = 1
164169
if Pivot

test/lu.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,15 @@ using StaticArrays, Test, LinearAlgebra
1111
@test occursin(r"^StaticArrays.LU.*L factor.*U factor"s, sprint(show, MIME("text/plain"), F))
1212
end
1313

14-
@testset "LU decomposition ($m×$n, pivot=$pivot)" for pivot in (true, false), m in [0:4..., 15], n in [0:4..., 15]
15-
a = SMatrix{m,n,Int}(1:(m*n))
14+
@testset "LU decomposition ($m×$n, pivot=$pivot, wrapper=$wrapper)" for pivot in (true, false), m in [0:4..., 15], n in [0:4..., 15], wrapper in [identity, Symmetric, Hermitian]
15+
16+
a = if m == n && m > 0
17+
wrapper(SMatrix{m,n,Int}(1:(m*n)))
18+
elseif wrapper !== identity
19+
continue
20+
else
21+
SMatrix{m,n,Int}(1:(m*n))
22+
end
1623
l, u, p = @inferred(lu(a, Val{pivot}(); check = false))
1724

1825
# expected types
@@ -48,6 +55,7 @@ end
4855
# decomposition is correct
4956
l_u = l*u
5057
@test l*u a[p,:]
58+
5159
end
5260

5361
@testset "LU division ($m×$n)" for m in [1:4..., 15], n in [1:4..., 15]

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,10 @@ if TEST_GROUP ∈ ["", "all", "group-A"]
6969
addtests("expm.jl")
7070
addtests("sqrtm.jl")
7171
addtests("lyap.jl")
72-
addtests("lu.jl")
7372
end
7473

7574
if TEST_GROUP ["", "all", "group-B"]
75+
addtests("lu.jl")
7676
addtests("qr.jl")
7777
addtests("chol.jl") # hermitian_type(::Type{Any}) for block algorithm
7878
addtests("deque.jl")

test/solve.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
using StaticArrays, Test, LinearAlgebra
22

33
@testset "Solving linear system" begin
4-
@testset "Problem size: $n x $n. Matrix type: $m. Element type: $elty" for n in (1,2,3,4,5,8,15),
4+
@testset "Problem size: $n x $n. Matrix type: $m. Element type: $elty, Wrapper: $wrapper" for n in (1,2,3,4,5,8,15),
55
(m, v) in ((SMatrix{n,n}, SVector{n}), (MMatrix{n,n}, MVector{n})),
6-
elty in (Float64, Int)
6+
elty in (Float64, Int), wrapper in (identity, Symmetric, Hermitian)
77

8-
A = elty.(rand(-99:2:99, n, n))
8+
A = wrapper(elty.(rand(-99:2:99, n, n)))
99
b = A * elty.(rand(2:5, n))
10-
@test m(A)\v(b) A\b
10+
@test wrapper(m(A))\v(b) wrapper(A)\b
1111
end
1212

1313
m1 = SMatrix{5,5}(1.0I)

0 commit comments

Comments
 (0)