@@ -11,28 +11,46 @@ function lu(A::StaticMatrix{N,N}, pivot::Union{Type{Val{false}},Type{Val{true}}}
11
11
(LowerTriangular (L), UpperTriangular (U), p)
12
12
end
13
13
14
- _lu (A:: StaticMatrix{0,0,T} , :: Type{Val{Pivot}} ) where {T,Pivot} =
14
+ @generated function _lu (A:: StaticMatrix{M,N,T} , pivot) where {M,N,T}
15
+ if M* N ≤ 14 * 14
16
+ :(__lu (A, pivot))
17
+ else
18
+ quote
19
+ # call through to Base to avoid excessive time spent on type inference for large matrices
20
+ f = lufact (Matrix (A), pivot)
21
+ # Trick to get the output eltype - can't rely on the result of f[:L] as
22
+ # it's not type inferrable.
23
+ T2 = arithmetic_closure (T)
24
+ L = similar_type (A, T2, Size ($ M, $ (min (M,N))))(f[:L ])
25
+ U = similar_type (A, T2, Size ($ (min (M,N)), $ N))(f[:U ])
26
+ p = similar_type (A, Int, Size ($ M))(f[:p ])
27
+ (L,U,p)
28
+ end
29
+ end
30
+ end
31
+
32
+ __lu (A:: StaticMatrix{0,0,T} , :: Type{Val{Pivot}} ) where {T,Pivot} =
15
33
(SMatrix {0,0,typeof(one(T))} (), A, SVector {0,Int} ())
16
34
17
- _lu (A:: StaticMatrix{0,1,T} , :: Type{Val{Pivot}} ) where {T,Pivot} =
35
+ __lu (A:: StaticMatrix{0,1,T} , :: Type{Val{Pivot}} ) where {T,Pivot} =
18
36
(SMatrix {0,0,typeof(one(T))} (), A, SVector {0,Int} ())
19
37
20
- _lu (A:: StaticMatrix{0,N,T} , :: Type{Val{Pivot}} ) where {T,N,Pivot} =
38
+ __lu (A:: StaticMatrix{0,N,T} , :: Type{Val{Pivot}} ) where {T,N,Pivot} =
21
39
(SMatrix {0,0,typeof(one(T))} (), A, SVector {0,Int} ())
22
40
23
- _lu (A:: StaticMatrix{1,0,T} , :: Type{Val{Pivot}} ) where {T,Pivot} =
41
+ __lu (A:: StaticMatrix{1,0,T} , :: Type{Val{Pivot}} ) where {T,Pivot} =
24
42
(SMatrix {1,0,typeof(one(T))} (), SMatrix {0,0,T} (), SVector {1,Int} (1 ))
25
43
26
- _lu (A:: StaticMatrix{M,0,T} , :: Type{Val{Pivot}} ) where {T,M,Pivot} =
44
+ __lu (A:: StaticMatrix{M,0,T} , :: Type{Val{Pivot}} ) where {T,M,Pivot} =
27
45
(SMatrix {M,0,typeof(one(T))} (), SMatrix {0,0,T} (), SVector {M,Int} (1 : M))
28
46
29
- _lu (A:: StaticMatrix{1,1,T} , :: Type{Val{Pivot}} ) where {T,Pivot} =
47
+ __lu (A:: StaticMatrix{1,1,T} , :: Type{Val{Pivot}} ) where {T,Pivot} =
30
48
(SMatrix {1,1} (one (T)), A, SVector (1 ))
31
49
32
- _lu (A:: StaticMatrix{1,N,T} , :: Type{Val{Pivot}} ) where {N,T,Pivot} =
50
+ __lu (A:: StaticMatrix{1,N,T} , :: Type{Val{Pivot}} ) where {N,T,Pivot} =
33
51
(SMatrix {1,1,T} (one (T)), A, SVector {1,Int} (1 ))
34
52
35
- function _lu (A:: StaticMatrix{M,1} , :: Type{Val{Pivot}} ) where {M,Pivot}
53
+ function __lu (A:: StaticMatrix{M,1} , :: Type{Val{Pivot}} ) where {M,Pivot}
36
54
@inbounds begin
37
55
kp = 1
38
56
if Pivot
@@ -62,7 +80,7 @@ function _lu(A::StaticMatrix{M,1}, ::Type{Val{Pivot}}) where {M,Pivot}
62
80
return (SMatrix {M,1} (L), U, p)
63
81
end
64
82
65
- function _lu (A:: StaticMatrix{M,N,T} , :: Type{Val{Pivot}} ) where {M,N,T,Pivot}
83
+ function __lu (A:: StaticMatrix{M,N,T} , :: Type{Val{Pivot}} ) where {M,N,T,Pivot}
66
84
@inbounds begin
67
85
kp = 1
68
86
if Pivot
@@ -89,7 +107,7 @@ function _lu(A::StaticMatrix{M,N,T}, ::Type{Val{Pivot}}) where {M,N,T,Pivot}
89
107
90
108
# Update the rest
91
109
Arest = A[ps,tailindices (Val{N})] - Ls* Ufirst[:,tailindices (Val{N})]
92
- Lrest, Urest, prest = _lu (Arest, Val{Pivot})
110
+ Lrest, Urest, prest = __lu (Arest, Val{Pivot})
93
111
p = [SVector {1,Int} (kp); ps[prest]]
94
112
L = [[SVector {1} (one (eltype (Ls))); Ls[prest]] [zeros (SMatrix {1} (Lrest[1 ,:])); Lrest]]
95
113
U = [Ufirst; [zeros (Urest[:,1 ]) Urest]]
0 commit comments