Skip to content

Commit 9e869e9

Browse files
committed
Non-allocating LU decomposition
1 parent 99181cf commit 9e869e9

File tree

1 file changed

+90
-14
lines changed

1 file changed

+90
-14
lines changed

src/lu.jl

Lines changed: 90 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,105 @@
11
# LU decomposition
22
function lu(A::StaticMatrix, pivot::Union{Type{Val{false}},Type{Val{true}}}=Val{true})
3-
L,U,p = _lu(Size(A), A, pivot)
3+
L,U,p = _lu(A, pivot)
44
(L,U,p)
55
end
66

77
# For the square version, return explicit lower and upper triangular matrices.
88
# We would do this for the rectangular case too, but Base doesn't support that.
99
function lu(A::StaticMatrix{N,N}, pivot::Union{Type{Val{false}},Type{Val{true}}}=Val{true}) where {N}
10-
L,U,p = _lu(Size(A), A, pivot)
10+
L,U,p = _lu(A, pivot)
1111
(LowerTriangular(L), UpperTriangular(U), p)
1212
end
1313

14+
_lu(A::StaticMatrix{0,0,T}, ::Type{Val{Pivot}}) where {T,Pivot} =
15+
(SMatrix{0,0,typeof(one(T))}(), A, SVector{0,Int}())
1416

15-
@inline function _lu(::Size{S}, A::StaticMatrix, pivot) where {S}
16-
# For now, just call through to Base.
17-
# TODO: statically sized LU without allocations!
18-
f = lufact(Matrix(A), pivot)
19-
T = eltype(A)
20-
# Trick to get the output eltype - can't rely on the result of f[:L] as
21-
# it's not type inferrable.
22-
T2 = arithmetic_closure(T)
23-
L = similar_type(A, T2, Size(Size(A)[1], diagsize(A)))(f[:L])
24-
U = similar_type(A, T2, Size(diagsize(A), Size(A)[2]))(f[:U])
25-
p = similar_type(A, Int, Size(Size(A)[1]))(f[:p])
26-
(L,U,p)
17+
_lu(A::StaticMatrix{0,1,T}, ::Type{Val{Pivot}}) where {T,Pivot} =
18+
(SMatrix{0,0,typeof(one(T))}(), A, SVector{0,Int}())
19+
20+
_lu(A::StaticMatrix{0,N,T}, ::Type{Val{Pivot}}) where {T,N,Pivot} =
21+
(SMatrix{0,0,typeof(one(T))}(), A, SVector{0,Int}())
22+
23+
_lu(A::StaticMatrix{1,0,T}, ::Type{Val{Pivot}}) where {T,Pivot} =
24+
(SMatrix{1,0,typeof(one(T))}(), SMatrix{0,0,T}(), SVector{1,Int}(1))
25+
26+
_lu(A::StaticMatrix{M,0,T}, ::Type{Val{Pivot}}) where {T,M,Pivot} =
27+
(SMatrix{M,0,typeof(one(T))}(), SMatrix{0,0,T}(), SVector{M,Int}(1:M))
28+
29+
_lu(A::StaticMatrix{1,1,T}, ::Type{Val{Pivot}}) where {T,Pivot} =
30+
(SMatrix{1,1}(one(T)), A, SVector(1))
31+
32+
_lu(A::StaticMatrix{1,N,T}, ::Type{Val{Pivot}}) where {N,T,Pivot} =
33+
(SMatrix{1,1,T}(one(T)), A, SVector{1,Int}(1))
34+
35+
function _lu(A::StaticMatrix{M,1}, ::Type{Val{Pivot}}) where {M,Pivot}
36+
kp = 1
37+
if Pivot
38+
amax = abs(A[1,1])
39+
for i = 2:M
40+
absi = abs(A[i,1])
41+
if absi > amax
42+
kp = i
43+
amax = absi
44+
end
45+
end
46+
end
47+
ps = tailindices(Val{M})
48+
if kp != 1
49+
ps = setindex(ps, 1, kp-1)
50+
end
51+
U = SMatrix{1,1}(A[kp,1])
52+
# Scale first column
53+
Akkinv = inv(A[kp,1])
54+
Ls = A[ps,1] * Akkinv
55+
if !isfinite(Akkinv)
56+
Ls = zeros(Ls)
57+
end
58+
L = [SVector{1}(one(eltype(Ls))); Ls]
59+
p = [SVector{1,Int}(kp); ps]
60+
return (SMatrix{M,1}(L), U, p)
61+
end
62+
63+
function _lu(A::StaticMatrix{M,N,T}, ::Type{Val{Pivot}}) where {M,N,T,Pivot}
64+
kp = 1
65+
if Pivot
66+
amax = abs(A[1,1])
67+
for i = 2:M
68+
absi = abs(A[i,1])
69+
if absi > amax
70+
kp = i
71+
amax = absi
72+
end
73+
end
74+
end
75+
ps = tailindices(Val{M})
76+
if kp != 1
77+
ps = setindex(ps, 1, kp-1)
78+
end
79+
Ufirst = SMatrix{1,N}(A[kp,:])
80+
# Scale first column
81+
Akkinv = inv(A[kp,1])
82+
Ls = A[ps,1] * Akkinv
83+
if !isfinite(Akkinv)
84+
Ls = zeros(Ls)
85+
end
86+
87+
# Update the rest
88+
Arest = A[ps,tailindices(Val{N})] - Ls*Ufirst[:,tailindices(Val{N})]
89+
Lrest, Urest, prest = _lu(Arest, Val{Pivot})
90+
p = [SVector{1,Int}(kp); ps[prest]]
91+
L = [[SVector{1}(one(eltype(Ls))); Ls[prest]] [zeros(SMatrix{1}(Lrest[1,:])); Lrest]]
92+
U = [Ufirst; [zeros(Urest[:,1]) Urest]]
93+
94+
return (L, U, p)
95+
end
96+
97+
# Create SVector(2,3,...,M)
98+
# Note that
99+
# tailindices(::Type{Val{M}}) where {M} = SVector(Base.tail(ntuple(identity, Val{M})))
100+
# works, too, but is only inferrable for M ≤ 14 (at least up to Julia 0.7.0-DEV.4021)
101+
@generated function tailindices(::Type{Val{M}}) where {M}
102+
:(SVector{$(M-1),Int}($(tuple(2:M...))))
27103
end
28104

29105
# Base.lufact() interface is fairly inherently type unstable. Punt on

0 commit comments

Comments
 (0)