Skip to content

Commit 79436af

Browse files
committed
Let LU factorization fall back to the Base one for large matrices
If `length(A) > 14*14`, `lufact(::Matrix)` is used to avoid excessive type inference times (at the cost of allocation at run-time).
1 parent 5037bad commit 79436af

File tree

2 files changed

+29
-11
lines changed

2 files changed

+29
-11
lines changed

src/lu.jl

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,28 +11,46 @@ function lu(A::StaticMatrix{N,N}, pivot::Union{Type{Val{false}},Type{Val{true}}}
1111
(LowerTriangular(L), UpperTriangular(U), p)
1212
end
1313

14-
_lu(A::StaticMatrix{0,0,T}, ::Type{Val{Pivot}}) where {T,Pivot} =
14+
@generated function _lu(A::StaticMatrix{M,N,T}, pivot) where {M,N,T}
15+
if M*N 14*14
16+
:(__lu(A, pivot))
17+
else
18+
quote
19+
# call through to Base to avoid excessive time spent on type inference for large matrices
20+
f = lufact(Matrix(A), pivot)
21+
# Trick to get the output eltype - can't rely on the result of f[:L] as
22+
# it's not type inferrable.
23+
T2 = arithmetic_closure(T)
24+
L = similar_type(A, T2, Size($M, $(min(M,N))))(f[:L])
25+
U = similar_type(A, T2, Size($(min(M,N)), $N))(f[:U])
26+
p = similar_type(A, Int, Size($M))(f[:p])
27+
(L,U,p)
28+
end
29+
end
30+
end
31+
32+
__lu(A::StaticMatrix{0,0,T}, ::Type{Val{Pivot}}) where {T,Pivot} =
1533
(SMatrix{0,0,typeof(one(T))}(), A, SVector{0,Int}())
1634

17-
_lu(A::StaticMatrix{0,1,T}, ::Type{Val{Pivot}}) where {T,Pivot} =
35+
__lu(A::StaticMatrix{0,1,T}, ::Type{Val{Pivot}}) where {T,Pivot} =
1836
(SMatrix{0,0,typeof(one(T))}(), A, SVector{0,Int}())
1937

20-
_lu(A::StaticMatrix{0,N,T}, ::Type{Val{Pivot}}) where {T,N,Pivot} =
38+
__lu(A::StaticMatrix{0,N,T}, ::Type{Val{Pivot}}) where {T,N,Pivot} =
2139
(SMatrix{0,0,typeof(one(T))}(), A, SVector{0,Int}())
2240

23-
_lu(A::StaticMatrix{1,0,T}, ::Type{Val{Pivot}}) where {T,Pivot} =
41+
__lu(A::StaticMatrix{1,0,T}, ::Type{Val{Pivot}}) where {T,Pivot} =
2442
(SMatrix{1,0,typeof(one(T))}(), SMatrix{0,0,T}(), SVector{1,Int}(1))
2543

26-
_lu(A::StaticMatrix{M,0,T}, ::Type{Val{Pivot}}) where {T,M,Pivot} =
44+
__lu(A::StaticMatrix{M,0,T}, ::Type{Val{Pivot}}) where {T,M,Pivot} =
2745
(SMatrix{M,0,typeof(one(T))}(), SMatrix{0,0,T}(), SVector{M,Int}(1:M))
2846

29-
_lu(A::StaticMatrix{1,1,T}, ::Type{Val{Pivot}}) where {T,Pivot} =
47+
__lu(A::StaticMatrix{1,1,T}, ::Type{Val{Pivot}}) where {T,Pivot} =
3048
(SMatrix{1,1}(one(T)), A, SVector(1))
3149

32-
_lu(A::StaticMatrix{1,N,T}, ::Type{Val{Pivot}}) where {N,T,Pivot} =
50+
__lu(A::StaticMatrix{1,N,T}, ::Type{Val{Pivot}}) where {N,T,Pivot} =
3351
(SMatrix{1,1,T}(one(T)), A, SVector{1,Int}(1))
3452

35-
function _lu(A::StaticMatrix{M,1}, ::Type{Val{Pivot}}) where {M,Pivot}
53+
function __lu(A::StaticMatrix{M,1}, ::Type{Val{Pivot}}) where {M,Pivot}
3654
@inbounds begin
3755
kp = 1
3856
if Pivot
@@ -62,7 +80,7 @@ function _lu(A::StaticMatrix{M,1}, ::Type{Val{Pivot}}) where {M,Pivot}
6280
return (SMatrix{M,1}(L), U, p)
6381
end
6482

65-
function _lu(A::StaticMatrix{M,N,T}, ::Type{Val{Pivot}}) where {M,N,T,Pivot}
83+
function __lu(A::StaticMatrix{M,N,T}, ::Type{Val{Pivot}}) where {M,N,T,Pivot}
6684
@inbounds begin
6785
kp = 1
6886
if Pivot
@@ -89,7 +107,7 @@ function _lu(A::StaticMatrix{M,N,T}, ::Type{Val{Pivot}}) where {M,N,T,Pivot}
89107

90108
# Update the rest
91109
Arest = A[ps,tailindices(Val{N})] - Ls*Ufirst[:,tailindices(Val{N})]
92-
Lrest, Urest, prest = _lu(Arest, Val{Pivot})
110+
Lrest, Urest, prest = __lu(Arest, Val{Pivot})
93111
p = [SVector{1,Int}(kp); ps[prest]]
94112
L = [[SVector{1}(one(eltype(Ls))); Ls[prest]] [zeros(SMatrix{1}(Lrest[1,:])); Lrest]]
95113
U = [Ufirst; [zeros(Urest[:,1]) Urest]]

test/lu.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using StaticArrays, Base.Test
22

33
@testset "LU decomposition (pivot=$pivot)" for pivot in (true, false)
4-
@testset "$m×$n" for m in 0:4, n in 0:4
4+
@testset "$m×$n" for m in [0:4..., 15, 50], n in [0:4..., 15, 50]
55
a = SMatrix{m,n,Int}(1:(m*n))
66
l, u, p = @inferred(lu(a, Val{pivot}))
77

0 commit comments

Comments
 (0)