Skip to content

Commit 80443b1

Browse files
Revamp Cholesky implementation (#311)
* Revamp Cholesky implementation * Bump patch * Tidy up + fix syntax * Reinstate tests * Update src/rulesets/LinearAlgebra/factorization.jl Co-authored-by: Seth Axen <seth.axen@gmail.com> * Update src/rulesets/LinearAlgebra/factorization.jl Co-authored-by: Seth Axen <seth.axen@gmail.com> * Conjugate trsm! Co-authored-by: Seth Axen <seth.axen@gmail.com> * Rename to DeltaC and use BlasFloat * Conjugate in both places * Change diag scaling for Complex Co-authored-by: Seth Axen <seth.axen@gmail.com> * Fix naming * Use HermOrSym and generalise Diagonal chol to complex * Revert to Real * Fix naming * Test cholesky(::Real) * Tidy up Diagonal cholesky * Restrict to BlasReal * Refactor Diagonal pullback slightly * Reinstate all tests * Update src/rulesets/LinearAlgebra/factorization.jl Co-authored-by: Seth Axen <seth.axen@gmail.com> * Fix up broken tests * Address code review queries Co-authored-by: Seth Axen <seth.axen@gmail.com>
1 parent 2dab3ab commit 80443b1

File tree

3 files changed

+117
-230
lines changed

3 files changed

+117
-230
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "0.7.36"
3+
version = "0.7.37"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/rulesets/LinearAlgebra/factorization.jl

Lines changed: 60 additions & 169 deletions
Original file line numberDiff line numberDiff line change
@@ -70,20 +70,69 @@ end
7070
##### `cholesky`
7171
#####
7272

73-
function rrule(::typeof(cholesky), X::AbstractMatrix{<:Real})
74-
F = cholesky(X)
75-
function cholesky_pullback::Composite)
76-
∂X = if F.uplo === 'U'
77-
chol_blocked_rev.U, F.U, 25, true)
78-
else
79-
chol_blocked_rev.L, F.L, 25, false)
80-
end
81-
return (NO_FIELDS, ∂X)
73+
function rrule(::typeof(cholesky), A::Real, uplo::Symbol=:U)
74+
C = cholesky(A, uplo)
75+
function cholesky_pullback(ΔC::Composite)
76+
return NO_FIELDS, ΔC.factors[1, 1] / (2 * C.U[1, 1]), DoesNotExist()
8277
end
83-
return F, cholesky_pullback
78+
return C, cholesky_pullback
79+
end
80+
81+
function rrule(::typeof(cholesky), A::Diagonal{<:Real}, ::Val{false}; check::Bool=true)
82+
C = cholesky(A, Val(false); check=check)
83+
function cholesky_pullback(ΔC::Composite)
84+
= Diagonal(diag(ΔC.factors) .* inv.(2 .* C.factors.diag))
85+
return NO_FIELDS, Ā, DoesNotExist()
86+
end
87+
return C, cholesky_pullback
88+
end
89+
90+
# The appropriate cotangent is different depending upon whether A is Symmetric / Hermitian,
91+
# or just a StridedMatrix.
92+
# Implementation due to Seeger, Matthias, et al. "Auto-differentiating linear algebra."
93+
function rrule(
94+
::typeof(cholesky),
95+
A::LinearAlgebra.HermOrSym{<:LinearAlgebra.BlasReal, <:StridedMatrix},
96+
::Val{false};
97+
check::Bool=true,
98+
)
99+
C = cholesky(A, Val(false); check=check)
100+
function cholesky_pullback(ΔC::Composite)
101+
Ā, U = _cholesky_pullback_shared_code(C, ΔC)
102+
= BLAS.trsm!('R', 'U', 'C', 'N', one(eltype(Ā)) / 2, U.data, Ā)
103+
return NO_FIELDS, _symhermtype(A)(Ā), DoesNotExist()
104+
end
105+
return C, cholesky_pullback
106+
end
107+
108+
function rrule(
109+
::typeof(cholesky),
110+
A::StridedMatrix{<:LinearAlgebra.BlasReal},
111+
::Val{false};
112+
check::Bool=true,
113+
)
114+
C = cholesky(A, Val(false); check=check)
115+
function cholesky_pullback(ΔC::Composite)
116+
Ā, U = _cholesky_pullback_shared_code(C, ΔC)
117+
= BLAS.trsm!('R', 'U', 'C', 'N', one(eltype(Ā)), U.data, Ā)
118+
idx = diagind(Ā)
119+
@views Ā[idx] .= real.(Ā[idx]) ./ 2
120+
return (NO_FIELDS, UpperTriangular(Ā), DoesNotExist())
121+
end
122+
return C, cholesky_pullback
123+
end
124+
125+
function _cholesky_pullback_shared_code(C, ΔC)
126+
U = C.U
127+
= ΔC.U
128+
= similar(U.data)
129+
= mul!(Ā, Ū, U')
130+
= LinearAlgebra.copytri!(Ā, 'U', true)
131+
= ldiv!(U, Ā)
132+
return Ā, U
84133
end
85134

86-
function rrule(::typeof(getproperty), F::T, x::Symbol) where T <: Cholesky
135+
function rrule(::typeof(getproperty), F::T, x::Symbol) where {T <: Cholesky}
87136
function getproperty_cholesky_pullback(Ȳ)
88137
C = Composite{T}
89138
∂F = if x === :U
@@ -103,161 +152,3 @@ function rrule(::typeof(getproperty), F::T, x::Symbol) where T <: Cholesky
103152
end
104153
return getproperty(F, x), getproperty_cholesky_pullback
105154
end
106-
107-
# See "Differentiation of the Cholesky decomposition" (Murray 2016), pages 5-9 in particular,
108-
# for derivations. Here we're implementing the algorithms and their transposes.
109-
110-
"""
111-
level2partition(A::AbstractMatrix, j::Integer, upper::Bool)
112-
113-
Returns views to various bits of the lower triangle of `A` according to the
114-
`level2partition` procedure defined in [1] if `upper` is `false`. If `upper` is `true` then
115-
the transposed views are returned from the upper triangle of `A`.
116-
117-
[1]: "Differentiation of the Cholesky decomposition", Murray 2016
118-
"""
119-
function level2partition(A::AbstractMatrix, j::Integer, upper::Bool)
120-
n = checksquare(A)
121-
@boundscheck checkbounds(1:n, j)
122-
if upper
123-
r = view(A, 1:j-1, j)
124-
d = view(A, j, j)
125-
B = view(A, 1:j-1, j+1:n)
126-
c = view(A, j, j+1:n)
127-
else
128-
r = view(A, j, 1:j-1)
129-
d = view(A, j, j)
130-
B = view(A, j+1:n, 1:j-1)
131-
c = view(A, j+1:n, j)
132-
end
133-
return r, d, B, c
134-
end
135-
136-
"""
137-
level3partition(A::AbstractMatrix, j::Integer, k::Integer, upper::Bool)
138-
139-
Returns views to various bits of the lower triangle of `A` according to the
140-
`level3partition` procedure defined in [1] if `upper` is `false`. If `upper` is `true` then
141-
the transposed views are returned from the upper triangle of `A`.
142-
143-
[1]: "Differentiation of the Cholesky decomposition", Murray 2016
144-
"""
145-
function level3partition(A::AbstractMatrix, j::Integer, k::Integer, upper::Bool)
146-
n = checksquare(A)
147-
@boundscheck checkbounds(1:n, j)
148-
if upper
149-
R = view(A, 1:j-1, j:k)
150-
D = view(A, j:k, j:k)
151-
B = view(A, 1:j-1, k+1:n)
152-
C = view(A, j:k, k+1:n)
153-
else
154-
R = view(A, j:k, 1:j-1)
155-
D = view(A, j:k, j:k)
156-
B = view(A, k+1:n, 1:j-1)
157-
C = view(A, k+1:n, j:k)
158-
end
159-
return R, D, B, C
160-
end
161-
162-
"""
163-
chol_unblocked_rev!(Ā::AbstractMatrix, L::AbstractMatrix, upper::Bool)
164-
165-
Compute the reverse-mode sensitivities of the Cholesky factorization in an unblocked manner.
166-
If `upper` is `false`, then the sensitivites are computed from and stored in the lower triangle
167-
of `Ā` and `L` respectively. If `upper` is `true` then they are computed and stored in the
168-
upper triangles. If at input `upper` is `false` and `tril(Ā) = L̄`, at output
169-
`tril(Ā) = tril(Σ̄)`, where `Σ = LLᵀ`. Analogously, if at input `upper` is `true` and
170-
`triu(Ā) = triu(Ū)`, at output `triu(Ā) = triu(Σ̄)` where `Σ = UᵀU`.
171-
"""
172-
function chol_unblocked_rev!(Σ̄::AbstractMatrix{T}, L::AbstractMatrix{T}, upper::Bool) where T<:Real
173-
n = checksquare(Σ̄)
174-
j = n
175-
@inbounds for _ in 1:n
176-
r, d, B, c = level2partition(L, j, upper)
177-
r̄, d̄, B̄, c̄ = level2partition(Σ̄, j, upper)
178-
179-
# d̄ <- d̄ - c'c̄ / d.
180-
d̄[1] -= dot(c, c̄) / d[1]
181-
182-
# [d̄ c̄'] <- [d̄ c̄'] / d.
183-
./= d
184-
./= d
185-
186-
# r̄ <- r̄ - [d̄ c̄'] [r' B']'.
187-
= axpy!(-Σ̄[j,j], r, r̄)
188-
= gemv!(upper ? 'n' : 'T', -one(T), B, c̄, one(T), r̄)
189-
190-
# B̄ <- B̄ - c̄ r.
191-
= upper ? ger!(-one(T), r, c̄, B̄) : ger!(-one(T), c̄, r, B̄)
192-
./= 2
193-
j -= 1
194-
end
195-
return (upper ? triu! : tril!)(Σ̄)
196-
end
197-
198-
function chol_unblocked_rev(Σ̄::AbstractMatrix, L::AbstractMatrix, upper::Bool)
199-
return chol_unblocked_rev!(copy(Σ̄), L, upper)
200-
end
201-
202-
"""
203-
chol_blocked_rev!(Σ̄::StridedMatrix, L::StridedMatrix, nb::Integer, upper::Bool)
204-
205-
Compute the sensitivities of the Cholesky factorization using a blocked, cache-friendly
206-
procedure. `Σ̄` are the sensitivities of `L`, and will be transformed into the sensitivities
207-
of `Σ`, where `Σ = LLᵀ`. `nb` is the block size to use. If the upper triangle has been used
208-
to represent the factorization, that is `Σ = UᵀU` where `U := Lᵀ`, then this should be
209-
indicated by passing `upper = true`.
210-
"""
211-
function chol_blocked_rev!(Σ̄::StridedMatrix{T}, L::StridedMatrix{T}, nb::Integer, upper::Bool) where T<:Real
212-
n = checksquare(Σ̄)
213-
tmp = Matrix{T}(undef, nb, nb)
214-
k = n
215-
if upper
216-
@inbounds for _ in 1:nb:n
217-
j = max(1, k - nb + 1)
218-
R, D, B, C = level3partition(L, j, k, true)
219-
R̄, D̄, B̄, C̄ = level3partition(Σ̄, j, k, true)
220-
221-
= trsm!('L', 'U', 'N', 'N', one(T), D, C̄)
222-
gemm!('N', 'N', -one(T), R, C̄, one(T), B̄)
223-
gemm!('N', 'T', -one(T), C, C̄, one(T), D̄)
224-
chol_unblocked_rev!(D̄, D, true)
225-
gemm!('N', 'T', -one(T), B, C̄, one(T), R̄)
226-
if size(D̄, 1) == nb
227-
tmp = axpy!(one(T), D̄, transpose!(tmp, D̄))
228-
gemm!('N', 'N', -one(T), R, tmp, one(T), R̄)
229-
else
230-
gemm!('N', 'N', -one(T), R, D̄ +', one(T), R̄)
231-
end
232-
233-
k -= nb
234-
end
235-
return triu!(Σ̄)
236-
else
237-
@inbounds for _ in 1:nb:n
238-
j = max(1, k - nb + 1)
239-
R, D, B, C = level3partition(L, j, k, false)
240-
R̄, D̄, B̄, C̄ = level3partition(Σ̄, j, k, false)
241-
242-
= trsm!('R', 'L', 'N', 'N', one(T), D, C̄)
243-
gemm!('N', 'N', -one(T), C̄, R, one(T), B̄)
244-
gemm!('T', 'N', -one(T), C̄, C, one(T), D̄)
245-
chol_unblocked_rev!(D̄, D, false)
246-
gemm!('T', 'N', -one(T), C̄, B, one(T), R̄)
247-
if size(D̄, 1) == nb
248-
tmp = axpy!(one(T), D̄, transpose!(tmp, D̄))
249-
gemm!('N', 'N', -one(T), tmp, R, one(T), R̄)
250-
else
251-
gemm!('N', 'N', -one(T), D̄ +', R, one(T), R̄)
252-
end
253-
254-
k -= nb
255-
end
256-
return tril!(Σ̄)
257-
end
258-
end
259-
260-
function chol_blocked_rev(Σ̄::AbstractMatrix, L::AbstractMatrix, nb::Integer, upper::Bool)
261-
# Convert to `Matrix`s because blas functions require StridedMatrix input.
262-
return chol_blocked_rev!(Matrix(Σ̄), Matrix(L), nb, upper)
263-
end

test/rulesets/LinearAlgebra/factorization.jl

Lines changed: 56 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,15 @@
1-
using ChainRules: level2partition, level3partition, chol_blocked_rev, chol_unblocked_rev
1+
function FiniteDifferences.to_vec(C::Cholesky)
2+
C_vec, factors_from_vec = to_vec(C.factors)
3+
function cholesky_from_vec(v)
4+
return Cholesky(factors_from_vec(v), C.uplo, C.info)
5+
end
6+
return C_vec, cholesky_from_vec
7+
end
8+
9+
function FiniteDifferences.to_vec(x::Val)
10+
Val_from_vec(v) = x
11+
return Bool[], Val_from_vec
12+
end
213

314
@testset "Factorizations" begin
415
@testset "svd" begin
@@ -73,69 +84,54 @@ using ChainRules: level2partition, level3partition, chol_blocked_rev, chol_unblo
7384
@test ChainRules._eyesubx!(copy(X)) I - X
7485
end
7586
end
87+
88+
# These tests are generally a bit tricky to write because FiniteDifferences doesn't
89+
# have fantastic support for this stuff at the minute.
7690
@testset "cholesky" begin
77-
@testset "the thing" begin
78-
X = generate_well_conditioned_matrix(10)
79-
V = generate_well_conditioned_matrix(10)
80-
F, dX_pullback = rrule(cholesky, X)
81-
for p in [:U, :L]
82-
Y, dF_pullback = rrule(getproperty, F, p)
83-
= (p === :U ? UpperTriangular : LowerTriangular)(randn(size(Y)))
84-
(dself, dF, dp) = dF_pullback(Ȳ)
85-
@test dself === NO_FIELDS
86-
@test dp === DoesNotExist()
91+
@testset "Real" begin
92+
C = cholesky(rand() + 0.1)
93+
ΔC = Composite{typeof(C)}((factors=rand_tangent(C.factors)))
94+
rrule_test(cholesky, ΔC, (rand() + 0.1, randn()))
95+
end
96+
@testset "Diagonal{<:Real}" begin
97+
D = Diagonal(rand(5) .+ 0.1)
98+
C = cholesky(D)
99+
ΔC = Composite{typeof(C)}((factors=Diagonal(randn(5))))
100+
rrule_test(cholesky, ΔC, (D, Diagonal(randn(5))), (Val(false), nothing))
101+
end
87102

88-
# NOTE: We're doing Nabla-style testing here and avoiding using the `j′vp`
89-
# machinery from FiniteDifferences because that isn't set up to respect
90-
# necessary special properties of the input. In the case of the Cholesky
91-
# factorization, we need the input to be Hermitian.
92-
ΔF = unthunk(dF)
93-
_, dX = dX_pullback(ΔF)
94-
X̄_ad = dot(unthunk(dX), V)
95-
X̄_fd = _fdm(0.0) do ε
96-
dot(Ȳ, getproperty(cholesky(X .+ ε .* V), p))
97-
end
98-
@test X̄_ad X̄_fd rtol=1e-6 atol=1e-6
103+
X = generate_well_conditioned_matrix(10)
104+
V = generate_well_conditioned_matrix(10)
105+
F, dX_pullback = rrule(cholesky, X, Val(false))
106+
@testset "uplo=$p" for p in [:U, :L]
107+
Y, dF_pullback = rrule(getproperty, F, p)
108+
= (p === :U ? UpperTriangular : LowerTriangular)(randn(size(Y)))
109+
(dself, dF, dp) = dF_pullback(Ȳ)
110+
@test dself === NO_FIELDS
111+
@test dp === DoesNotExist()
112+
113+
# NOTE: We're doing Nabla-style testing here and avoiding using the `j′vp`
114+
# machinery from FiniteDifferences because that isn't set up to respect
115+
# necessary special properties of the input. In the case of the Cholesky
116+
# factorization, we need the input to be Hermitian.
117+
ΔF = unthunk(dF)
118+
_, dX = dX_pullback(ΔF)
119+
X̄_ad = dot(unthunk(dX), V)
120+
X̄_fd = central_fdm(5, 1)(0.000_001) do ε
121+
dot(Ȳ, getproperty(cholesky(X .+ ε .* V), p))
99122
end
123+
@test X̄_ad X̄_fd rtol=1e-4
100124
end
101-
@testset "helper functions" begin
102-
A = randn(5, 5)
103-
r, d, B2, c = level2partition(A, 4, false)
104-
R, D, B3, C = level3partition(A, 4, 4, false)
105-
@test all(r .== R')
106-
@test all(d .== D)
107-
@test B2[1] == B3[1]
108-
@test all(c .== C)
109-
110-
# Check that level 2 partition with `upper == true` is consistent with `false`
111-
rᵀ, dᵀ, B2ᵀ, cᵀ = level2partition(transpose(A), 4, true)
112-
@test r == rᵀ
113-
@test d == dᵀ
114-
@test B2' == B2ᵀ
115-
@test c == cᵀ
116-
117-
# Check that level 3 partition with `upper == true` is consistent with `false`
118-
R, D, B3, C = level3partition(A, 2, 4, false)
119-
Rᵀ, Dᵀ, B3ᵀ, Cᵀ = level3partition(transpose(A), 2, 4, true)
120-
@test transpose(R) == Rᵀ
121-
@test transpose(D) == Dᵀ
122-
@test transpose(B3) == B3ᵀ
123-
@test transpose(C) == Cᵀ
124-
125-
A = Matrix(LowerTriangular(randn(10, 10)))
126-
= Matrix(LowerTriangular(randn(10, 10)))
127-
# NOTE: BLAS gets angry if we don't materialize the Transpose objects first
128-
B = Matrix(transpose(A))
129-
= Matrix(transpose(Ā))
130-
@test chol_unblocked_rev(Ā, A, false) chol_blocked_rev(Ā, A, 1, false)
131-
@test chol_unblocked_rev(Ā, A, false) chol_blocked_rev(Ā, A, 3, false)
132-
@test chol_unblocked_rev(Ā, A, false) chol_blocked_rev(Ā, A, 5, false)
133-
@test chol_unblocked_rev(Ā, A, false) chol_blocked_rev(Ā, A, 10, false)
134-
@test chol_unblocked_rev(Ā, A, false) transpose(chol_unblocked_rev(B̄, B, true))
135-
136-
@test chol_unblocked_rev(B̄, B, true) chol_blocked_rev(B̄, B, 1, true)
137-
@test chol_unblocked_rev(B̄, B, true) chol_blocked_rev(B̄, B, 5, true)
138-
@test chol_unblocked_rev(B̄, B, true) chol_blocked_rev(B̄, B, 10, true)
125+
126+
# Ensure that cotangents of cholesky(::StridedMatrix) and
127+
# (cholesky ∘ Symmetric)(::StridedMatrix) are equal.
128+
@testset "Symmetric" begin
129+
X_symmetric, sym_back = rrule(Symmetric, X, :U)
130+
C, chol_back_sym = rrule(cholesky, X_symmetric, Val(false))
131+
132+
Δ = Composite{typeof(C)}((U=UpperTriangular(randn(size(X)))))
133+
ΔX_symmetric = chol_back_sym(Δ)[2]
134+
@test sym_back(ΔX_symmetric)[2] dX_pullback(Δ)[2]
139135
end
140136
end
141137
end

0 commit comments

Comments
 (0)