Skip to content

Commit f9e6577

Browse files
authored
Merge pull request #368 from martinholters/mh/lu
Non-allocating LU decomposition
2 parents 0906515 + 79436af commit f9e6577

File tree

2 files changed

+147
-30
lines changed

2 files changed

+147
-30
lines changed

src/lu.jl

Lines changed: 111 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,126 @@
11
# LU decomposition
22
function lu(A::StaticMatrix, pivot::Union{Type{Val{false}},Type{Val{true}}}=Val{true})
3-
L,U,p = _lu(Size(A), A, pivot)
3+
L,U,p = _lu(A, pivot)
44
(L,U,p)
55
end
66

77
# For the square version, return explicit lower and upper triangular matrices.
88
# We would do this for the rectangular case too, but Base doesn't support that.
99
function lu(A::StaticMatrix{N,N}, pivot::Union{Type{Val{false}},Type{Val{true}}}=Val{true}) where {N}
10-
L,U,p = _lu(Size(A), A, pivot)
10+
L,U,p = _lu(A, pivot)
1111
(LowerTriangular(L), UpperTriangular(U), p)
1212
end
1313

14+
@generated function _lu(A::StaticMatrix{M,N,T}, pivot) where {M,N,T}
15+
if M*N 14*14
16+
:(__lu(A, pivot))
17+
else
18+
quote
19+
# call through to Base to avoid excessive time spent on type inference for large matrices
20+
f = lufact(Matrix(A), pivot)
21+
# Trick to get the output eltype - can't rely on the result of f[:L] as
22+
# it's not type inferrable.
23+
T2 = arithmetic_closure(T)
24+
L = similar_type(A, T2, Size($M, $(min(M,N))))(f[:L])
25+
U = similar_type(A, T2, Size($(min(M,N)), $N))(f[:U])
26+
p = similar_type(A, Int, Size($M))(f[:p])
27+
(L,U,p)
28+
end
29+
end
30+
end
1431

15-
@inline function _lu(::Size{S}, A::StaticMatrix, pivot) where {S}
16-
# For now, just call through to Base.
17-
# TODO: statically sized LU without allocations!
18-
f = lufact(Matrix(A), pivot)
19-
T = eltype(A)
20-
# Trick to get the output eltype - can't rely on the result of f[:L] as
21-
# it's not type inferrable.
22-
T2 = arithmetic_closure(T)
23-
L = similar_type(A, T2, Size(Size(A)[1], diagsize(A)))(f[:L])
24-
U = similar_type(A, T2, Size(diagsize(A), Size(A)[2]))(f[:U])
25-
p = similar_type(A, Int, Size(Size(A)[1]))(f[:p])
26-
(L,U,p)
32+
__lu(A::StaticMatrix{0,0,T}, ::Type{Val{Pivot}}) where {T,Pivot} =
33+
(SMatrix{0,0,typeof(one(T))}(), A, SVector{0,Int}())
34+
35+
__lu(A::StaticMatrix{0,1,T}, ::Type{Val{Pivot}}) where {T,Pivot} =
36+
(SMatrix{0,0,typeof(one(T))}(), A, SVector{0,Int}())
37+
38+
__lu(A::StaticMatrix{0,N,T}, ::Type{Val{Pivot}}) where {T,N,Pivot} =
39+
(SMatrix{0,0,typeof(one(T))}(), A, SVector{0,Int}())
40+
41+
__lu(A::StaticMatrix{1,0,T}, ::Type{Val{Pivot}}) where {T,Pivot} =
42+
(SMatrix{1,0,typeof(one(T))}(), SMatrix{0,0,T}(), SVector{1,Int}(1))
43+
44+
__lu(A::StaticMatrix{M,0,T}, ::Type{Val{Pivot}}) where {T,M,Pivot} =
45+
(SMatrix{M,0,typeof(one(T))}(), SMatrix{0,0,T}(), SVector{M,Int}(1:M))
46+
47+
__lu(A::StaticMatrix{1,1,T}, ::Type{Val{Pivot}}) where {T,Pivot} =
48+
(SMatrix{1,1}(one(T)), A, SVector(1))
49+
50+
__lu(A::StaticMatrix{1,N,T}, ::Type{Val{Pivot}}) where {N,T,Pivot} =
51+
(SMatrix{1,1,T}(one(T)), A, SVector{1,Int}(1))
52+
53+
function __lu(A::StaticMatrix{M,1}, ::Type{Val{Pivot}}) where {M,Pivot}
54+
@inbounds begin
55+
kp = 1
56+
if Pivot
57+
amax = abs(A[1,1])
58+
for i = 2:M
59+
absi = abs(A[i,1])
60+
if absi > amax
61+
kp = i
62+
amax = absi
63+
end
64+
end
65+
end
66+
ps = tailindices(Val{M})
67+
if kp != 1
68+
ps = setindex(ps, 1, kp-1)
69+
end
70+
U = SMatrix{1,1}(A[kp,1])
71+
# Scale first column
72+
Akkinv = inv(A[kp,1])
73+
Ls = A[ps,1] * Akkinv
74+
if !isfinite(Akkinv)
75+
Ls = zeros(Ls)
76+
end
77+
L = [SVector{1}(one(eltype(Ls))); Ls]
78+
p = [SVector{1,Int}(kp); ps]
79+
end
80+
return (SMatrix{M,1}(L), U, p)
81+
end
82+
83+
function __lu(A::StaticMatrix{M,N,T}, ::Type{Val{Pivot}}) where {M,N,T,Pivot}
84+
@inbounds begin
85+
kp = 1
86+
if Pivot
87+
amax = abs(A[1,1])
88+
for i = 2:M
89+
absi = abs(A[i,1])
90+
if absi > amax
91+
kp = i
92+
amax = absi
93+
end
94+
end
95+
end
96+
ps = tailindices(Val{M})
97+
if kp != 1
98+
ps = setindex(ps, 1, kp-1)
99+
end
100+
Ufirst = SMatrix{1,N}(A[kp,:])
101+
# Scale first column
102+
Akkinv = inv(A[kp,1])
103+
Ls = A[ps,1] * Akkinv
104+
if !isfinite(Akkinv)
105+
Ls = zeros(Ls)
106+
end
107+
108+
# Update the rest
109+
Arest = A[ps,tailindices(Val{N})] - Ls*Ufirst[:,tailindices(Val{N})]
110+
Lrest, Urest, prest = __lu(Arest, Val{Pivot})
111+
p = [SVector{1,Int}(kp); ps[prest]]
112+
L = [[SVector{1}(one(eltype(Ls))); Ls[prest]] [zeros(SMatrix{1}(Lrest[1,:])); Lrest]]
113+
U = [Ufirst; [zeros(Urest[:,1]) Urest]]
114+
end
115+
return (L, U, p)
116+
end
117+
118+
# Create SVector(2,3,...,M)
119+
# Note that
120+
# tailindices(::Type{Val{M}}) where {M} = SVector(Base.tail(ntuple(identity, Val{M})))
121+
# works, too, but is only inferrable for M ≤ 14 (at least up to Julia 0.7.0-DEV.4021)
122+
@generated function tailindices(::Type{Val{M}}) where {M}
123+
:(SVector{$(M-1),Int}($(tuple(2:M...))))
27124
end
28125

29126
# Base.lufact() interface is fairly inherently type unstable. Punt on

test/lu.jl

Lines changed: 36 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,41 @@
11
using StaticArrays, Base.Test
22

3-
@testset "LU decomposition" begin
4-
# Square case
5-
m22 = @SMatrix [1 2; 3 4]
6-
@test @inferred(lu(m22)) isa Tuple{LowerTriangular{Float64,SMatrix{2,2,Float64,4}}, UpperTriangular{Float64,SMatrix{2,2,Float64,4}}, SVector{2,Int}}
7-
@test lu(m22)[1]::LowerTriangular{<:Any,<:StaticMatrix} lu(Matrix(m22))[1]
8-
@test lu(m22)[2]::UpperTriangular{<:Any,<:StaticMatrix} lu(Matrix(m22))[2]
9-
@test lu(m22)[3]::StaticVector lu(Matrix(m22))[3]
3+
@testset "LU decomposition (pivot=$pivot)" for pivot in (true, false)
4+
@testset "$m×$n" for m in [0:4..., 15, 50], n in [0:4..., 15, 50]
5+
a = SMatrix{m,n,Int}(1:(m*n))
6+
l, u, p = @inferred(lu(a, Val{pivot}))
107

11-
# Rectangular case
12-
m23 = @SMatrix Float64[3 9 4; 6 6 2]
13-
@test @inferred(lu(m23)) isa Tuple{SMatrix{2,2,Float64,4}, SMatrix{2,3,Float64,6}, SVector{2,Int}}
14-
@test lu(m23)[1] lu(Matrix(m23))[1]
15-
@test lu(m23)[2] lu(Matrix(m23))[2]
16-
@test lu(m23)[3] lu(Matrix(m23))[3]
8+
# expected types
9+
@test p isa SVector{m,Int}
10+
if m==n
11+
@test l isa LowerTriangular{<:Any,<:SMatrix{m,n}}
12+
@test u isa UpperTriangular{<:Any,<:SMatrix{m,n}}
13+
else
14+
@test l isa SMatrix{m,min(m,n)}
15+
@test u isa SMatrix{min(m,n),n}
16+
end
1717

18-
@test lu(m23')[1] lu(Matrix(m23'))[1]
19-
@test lu(m23')[2] lu(Matrix(m23'))[2]
20-
@test lu(m23')[3] lu(Matrix(m23'))[3]
18+
if pivot
19+
# p is a permutation
20+
@test sort(p) == collect(1:m)
21+
else
22+
@test p == collect(1:m)
23+
end
24+
25+
# l is unit lower triangular
26+
for i=1:m, j=(i+1):size(l,2)
27+
@test iszero(l[i,j])
28+
end
29+
for i=1:size(l,2)
30+
@test l[i,i] == 1
31+
end
32+
33+
# u is upper triangular
34+
for i=1:size(u,1), j=1:i-1
35+
@test iszero(u[i,j])
36+
end
37+
38+
# decomposition is correct
39+
@test l*u a[p,:]
40+
end
2141
end

0 commit comments

Comments
 (0)