Skip to content

Commit 632dde0

Browse files
chriselrodmateuszbaranc42f
authored
Generalize unrolled Cholesky decomposition (#817)
* Unrolled Cholesky factorization for any (moderate) size * Ensure the matrix type is preserved for large-sized fallback. Co-authored-by: Mateusz Baran <mateuszbaran89@gmail.com> Co-authored-by: Chris Foster <chris42f@gmail.com>
1 parent 7dcbccf commit 632dde0

File tree

2 files changed

+81
-73
lines changed

2 files changed

+81
-73
lines changed

src/cholesky.jl

Lines changed: 36 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -12,48 +12,46 @@ end
1212
end
1313
@inline LinearAlgebra._chol!(A::StaticMatrix, ::Type{UpperTriangular}) = (cholesky(A).U, 0)
1414

15-
16-
@generated function _cholesky(::Size{(1,1)}, A::StaticMatrix)
17-
@assert size(A) == (1,1)
18-
19-
quote
20-
$(Expr(:meta, :inline))
21-
T = promote_type(typeof(sqrt(one(eltype(A)))), Float32)
22-
similar_type(A,T)((sqrt(A[1]), ))
15+
@generated function _cholesky(::Size{S}, A::StaticMatrix{M,M}) where {S,M}
16+
@assert (M,M) == S
17+
M > 24 && return :(_cholesky_large(Size{$S}(), A))
18+
q = Expr(:block)
19+
for n 1:M
20+
for m n:M
21+
L_m_n = Symbol(:L_,m,:_,n)
22+
push!(q.args, :($L_m_n = @inbounds A[$n, $m]))
23+
end
24+
for k 1:n-1, m n:M
25+
L_m_n = Symbol(:L_,m,:_,n)
26+
L_m_k = Symbol(:L_,m,:_,k)
27+
L_n_k = Symbol(:L_,n,:_,k)
28+
push!(q.args, :($L_m_n = muladd(-$L_m_k, $L_n_k', $L_m_n)))
29+
end
30+
L_n_n = Symbol(:L_,n,:_,n)
31+
push!(q.args, :($L_n_n = sqrt($L_n_n)))
32+
Linv_n_n = Symbol(:Linv_,n,:_,n)
33+
push!(q.args, :($Linv_n_n = inv($L_n_n)))
34+
for m n+1:M
35+
L_m_n = Symbol(:L_,m,:_,n)
36+
push!(q.args, :($L_m_n *= $Linv_n_n))
37+
end
2338
end
24-
end
25-
26-
@generated function _cholesky(::Size{(2,2)}, A::StaticMatrix)
27-
@assert size(A) == (2,2)
28-
29-
quote
30-
$(Expr(:meta, :inline))
31-
@inbounds a = sqrt(A[1])
32-
@inbounds b = A[3] / a
33-
@inbounds c = sqrt(A[4] - b'*b)
34-
T = promote_type(typeof(sqrt(one(eltype(A)))), Float32)
35-
similar_type(A,T)((a, zero(T), b, c))
36-
end
37-
end
38-
39-
@generated function _cholesky(::Size{(3,3)}, A::StaticMatrix)
40-
@assert size(A) == (3,3)
41-
42-
quote
43-
$(Expr(:meta, :inline))
44-
@inbounds a11 = sqrt(A[1])
45-
@inbounds a12 = A[4] / a11
46-
@inbounds a22 = sqrt(A[5] - a12'*a12)
47-
@inbounds a13 = A[7] / a11
48-
@inbounds a23 = (A[8] - a12'*a13) / a22
49-
@inbounds a33 = sqrt(A[9] - a13'*a13 - a23'*a23)
50-
T = promote_type(typeof(sqrt(one(eltype(A)))), Float32)
51-
similar_type(A,T)((a11, zero(T), zero(T), a12, a22, zero(T), a13, a23, a33))
39+
push!(q.args, :(T = promote_type(typeof(sqrt(one(eltype(A)))), Float32)))
40+
ret = Expr(:tuple)
41+
for n 1:M
42+
for m 1:n
43+
push!(ret.args, Symbol(:L_,n,:_,m))
44+
end
45+
for m n+1:M
46+
push!(ret.args, :(zero(T)))
47+
end
5248
end
49+
push!(q.args, :(similar_type(A, T)($ret)))
50+
return Expr(:block, Expr(:meta, :inline), q)
5351
end
5452

5553
# Otherwise default algorithm returning wrapped SizedArray
56-
@inline _cholesky(::Size{S}, A::StaticArray) where {S} =
57-
SizedArray{Tuple{S...}}(Matrix(cholesky(Hermitian(Matrix(A))).U))
54+
@inline _cholesky_large(::Size{S}, A::StaticArray) where {S} =
55+
similar_type(A)(cholesky(Hermitian(Matrix(A))).U)
5856

5957
LinearAlgebra.hermitian_type(::Type{SA}) where {T, S, SA<:SArray{S,T}} = Hermitian{T,SA}

test/chol.jl

Lines changed: 45 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -2,44 +2,54 @@ using StaticArrays, Test, LinearAlgebra
22
using LinearAlgebra: PosDefException
33

44
@testset "Cholesky decomposition" begin
5-
@testset "1×1" begin
6-
m = @SMatrix [4.0]
7-
(c,) = cholesky(m).U
8-
@test c === 2.0
9-
end
5+
for elty in [Float32, Float64, ComplexF64]
6+
@testset "1×1" begin
7+
m = @SMatrix [4.0]
8+
(c,) = cholesky(m).U
9+
@test c === 2.0
10+
end
1011

11-
@testset "2×2" for i = 1:100
12-
m_a = randn(2,2)
13-
#non hermitian
14-
@test_throws PosDefException cholesky(SMatrix{2,2}(m_a))
15-
m_a = m_a*m_a'
16-
m = SMatrix{2,2}(m_a)
17-
@test cholesky(Hermitian(m)).U cholesky(m_a).U
18-
@test cholesky(Hermitian(m)).L cholesky(m_a).L
19-
end
12+
@testset "2×2" for i = 1:100
13+
m_a = randn(elty, 2,2)
14+
#non hermitian
15+
@test_throws PosDefException cholesky(SMatrix{2,2}(m_a))
16+
m_a = m_a*m_a'
17+
m = SMatrix{2,2}(m_a)
18+
@test cholesky(Hermitian(m)).U cholesky(m_a).U
19+
@test cholesky(Hermitian(m)).L cholesky(m_a).L
20+
end
2021

21-
@testset "3×3" for i = 1:100
22-
m_a = randn(3,3)
23-
#non hermitian
24-
@test_throws PosDefException cholesky(SMatrix{3,3}(m_a))
25-
m_a = m_a*m_a'
26-
m = SMatrix{3,3}(m_a)
27-
@test cholesky(m).U cholesky(m_a).U
28-
@test cholesky(m).L cholesky(m_a).L
29-
@test cholesky(Hermitian(m)).U cholesky(m_a).U
30-
@test cholesky(Hermitian(m)).L cholesky(m_a).L
31-
end
32-
@testset "4×4" for i = 1:100
33-
m_a = randn(4,4)
34-
#non hermitian
35-
@test_throws PosDefException cholesky(SMatrix{4,4}(m_a))
36-
m_a = m_a*m_a'
37-
m = SMatrix{4,4}(m_a)
38-
@test cholesky(m).L cholesky(m_a).L
39-
@test cholesky(m).U cholesky(m_a).U
40-
@test cholesky(Hermitian(m)).L cholesky(m_a).L
41-
@test cholesky(Hermitian(m)).U cholesky(m_a).U
22+
@testset "3×3" for i = 1:100
23+
m_a = randn(elty, 3,3)
24+
#non hermitian
25+
@test_throws PosDefException cholesky(SMatrix{3,3}(m_a))
26+
m_a = m_a*m_a'
27+
m = SMatrix{3,3}(m_a)
28+
@test cholesky(m).U cholesky(m_a).U
29+
@test cholesky(m).L cholesky(m_a).L
30+
@test cholesky(Hermitian(m)).U cholesky(m_a).U
31+
@test cholesky(Hermitian(m)).L cholesky(m_a).L
32+
end
33+
@testset "4×4" for i = 1:100
34+
m_a = randn(elty, 4,4)
35+
#non hermitian
36+
@test_throws PosDefException cholesky(SMatrix{4,4}(m_a))
37+
m_a = m_a*m_a'
38+
m = SMatrix{4,4}(m_a)
39+
@test cholesky(m).L cholesky(m_a).L
40+
@test cholesky(m).U cholesky(m_a).U
41+
@test cholesky(Hermitian(m)).L cholesky(m_a).L
42+
@test cholesky(Hermitian(m)).U cholesky(m_a).U
43+
end
44+
45+
@testset "large (25x25)" begin
46+
m_a = randn(elty, 25, 25)
47+
m_a = m_a*m_a'
48+
m = SMatrix{25,25}(m_a)
49+
@test cholesky(m).L cholesky(m_a).L
50+
end
4251
end
52+
4353
@testset "static blockmatrix" for i = 1:10
4454
m_a = randn(3,3)
4555
m_a = m_a*m_a'

0 commit comments

Comments
 (0)