Skip to content

Commit e4e3f39

Browse files
committed
LU fixes
1 parent aaa99e9 commit e4e3f39

File tree

5 files changed

+26
-15
lines changed

5 files changed

+26
-15
lines changed

src/det.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@ end
4747
if prod(S) 14*14
4848
quote
4949
@_inline_meta
50-
L, U, p = lu(A)
51-
det(U)*_parity(p)
50+
LUp = lu(A)
51+
det(LUp.U)*_parity(LUp.p)
5252
end
5353
else
5454
:(@_inline_meta; det(Matrix(A)))
@@ -62,9 +62,9 @@ end
6262
if prod(S) 14*14
6363
quote
6464
@_inline_meta
65-
L, U, p = lu(A)
66-
d, s = logabsdet(U)
67-
d + log(s*_parity(p))
65+
LUp = lu(A)
66+
d, s = logabsdet(LUp.U)
67+
d + log(s*_parity(LUp.p))
6868
end
6969
else
7070
:(@_inline_meta; logdet(drop_sdims(A)))

src/inv.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,8 @@ end
6060
if prod(S) 14*14
6161
quote
6262
@_inline_meta
63-
L, U, p = lu(A)
64-
U \ (L \ eye(A)[p,:])
63+
LUp = lu(A)
64+
LUp.U \ (LUp.L \ eye(A)[LUp.p,:])
6565
end
6666
else
6767
:(@_inline_meta; similar_type(A)(inv(Matrix(A))))

src/lu.jl

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,27 @@
1+
# define our own LU type, since LinearAlgebra.LU requires p::Vector
2+
struct LU{L,U,p}
3+
L::L
4+
U::U
5+
p::p
6+
end
7+
8+
# iteration for destructuring into components
9+
Base.iterate(S::LU) = (S.L, Val(:U))
10+
Base.iterate(S::LU, ::Val{:U}) = (S.U, Val(:p))
11+
Base.iterate(S::LU, ::Val{:p}) = (S.p, Val(:done))
12+
Base.iterate(S::LU, ::Val{:done}) = nothing
13+
114
# LU decomposition
215
function lu(A::StaticMatrix, pivot::Union{Val{false},Val{true}}=Val(true))
3-
L,U,p = _lu(A, pivot)
4-
(L,U,p)
16+
L, U, p = _lu(A, pivot)
17+
LU(L, U, p)
518
end
619

720
# For the square version, return explicit lower and upper triangular matrices.
821
# We would do this for the rectangular case too, but Base doesn't support that.
922
function lu(A::StaticMatrix{N,N}, pivot::Union{Val{false},Val{true}}=Val(true)) where {N}
10-
L,U,p = _lu(A, pivot)
11-
(LowerTriangular(L), UpperTriangular(U), p)
23+
L, U, p = _lu(A, pivot)
24+
LU(LowerTriangular(L), UpperTriangular(U), p)
1225
end
1326

1427
@generated function _lu(A::StaticMatrix{M,N,T}, pivot) where {M,N,T}

src/solve.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ end
3434
if prod(Sa) 14*14
3535
quote
3636
@_inline_meta
37-
L, U, p = lu(a)
38-
U \ (L \ $(length(Sb) > 1 ? :(b[p,:]) : :(b[p])))
37+
LUp = lu(a)
38+
LUp.U \ (LUp.L \ $(length(Sb) > 1 ? :(b[LUp.p,:]) : :(b[LUp.p])))
3939
end
4040
else
4141
quote

test/lu.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
using StaticArrays, Test, LinearAlgebra
22

3-
@testset "LU decomposition" begin
43
@testset "LU decomposition ($m×$n, pivot=$pivot)" for pivot in (true, false), m in [0:4..., 15], n in [0:4..., 15]
54
a = SMatrix{m,n,Int}(1:(m*n))
65
l, u, p = @inferred(lu(a, Val{pivot}()))
@@ -38,4 +37,3 @@ using StaticArrays, Test, LinearAlgebra
3837
# decomposition is correct
3938
@test l*u a[p,:]
4039
end
41-
end

0 commit comments

Comments
 (0)