Skip to content

Commit 3b7836d

Browse files
authored
Add an rrule for the Cholesky decomposition (#44)
This is a direct port of the code from Nabla.
1 parent d8baa8f commit 3b7836d

File tree

2 files changed

+250
-0
lines changed

2 files changed

+250
-0
lines changed

src/rules/linalg/factorization.jl

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
using LinearAlgebra: checksquare
2+
using LinearAlgebra.BLAS: gemv, gemv!, gemm!, trsm!, axpy!, ger!
3+
14
#####
25
##### `svd`
36
#####
@@ -82,3 +85,187 @@ function _add!(X::AbstractMatrix{T}, Y::AbstractMatrix{T}) where T<:Real
8285
end
8386
X
8487
end
88+
89+
#####
90+
##### `cholesky`
91+
#####
92+
93+
function rrule(::typeof(cholesky), X::AbstractMatrix{<:Real})
94+
F = cholesky(X)
95+
∂X = Rule(Ȳ->chol_blocked_rev(Matrix(Ȳ), Matrix(F.U), 25, true))
96+
return F, ∂X
97+
end
98+
99+
function rrule(::typeof(getproperty), F::Cholesky, x::Symbol)
100+
if x === :U
101+
if F.uplo === 'U'
102+
∂F = Ȳ->UpperTriangular(Ȳ)
103+
else
104+
∂F = Ȳ->LowerTriangular(Ȳ')
105+
end
106+
elseif x === :L
107+
if F.uplo === 'L'
108+
∂F = Ȳ->LowerTriangular(Ȳ)
109+
else
110+
∂F = Ȳ->UpperTriangular(Ȳ')
111+
end
112+
end
113+
return getproperty(F, x), (Rule(∂F), DNERule())
114+
end
115+
116+
# See "Differentiation of the Cholesky decomposition" (Murray 2016), pages 5-9 in particular,
117+
# for derivations. Here we're implementing the algorithms and their transposes.
118+
119+
"""
120+
level2partition(A::AbstractMatrix, j::Integer, upper::Bool)
121+
122+
Returns views to various bits of the lower triangle of `A` according to the
123+
`level2partition` procedure defined in [1] if `upper` is `false`. If `upper` is `true` then
124+
the transposed views are returned from the upper triangle of `A`.
125+
126+
[1]: "Differentiation of the Cholesky decomposition", Murray 2016
127+
"""
128+
function level2partition(A::AbstractMatrix, j::Integer, upper::Bool)
129+
n = checksquare(A)
130+
@boundscheck checkbounds(1:n, j)
131+
if upper
132+
r = view(A, 1:j-1, j)
133+
d = view(A, j, j)
134+
B = view(A, 1:j-1, j+1:n)
135+
c = view(A, j, j+1:n)
136+
else
137+
r = view(A, j, 1:j-1)
138+
d = view(A, j, j)
139+
B = view(A, j+1:n, 1:j-1)
140+
c = view(A, j+1:n, j)
141+
end
142+
return r, d, B, c
143+
end
144+
145+
"""
146+
level3partition(A::AbstractMatrix, j::Integer, k::Integer, upper::Bool)
147+
148+
Returns views to various bits of the lower triangle of `A` according to the
149+
`level3partition` procedure defined in [1] if `upper` is `false`. If `upper` is `true` then
150+
the transposed views are returned from the upper triangle of `A`.
151+
152+
[1]: "Differentiation of the Cholesky decomposition", Murray 2016
153+
"""
154+
function level3partition(A::AbstractMatrix, j::Integer, k::Integer, upper::Bool)
155+
n = checksquare(A)
156+
@boundscheck checkbounds(1:n, j)
157+
if upper
158+
R = view(A, 1:j-1, j:k)
159+
D = view(A, j:k, j:k)
160+
B = view(A, 1:j-1, k+1:n)
161+
C = view(A, j:k, k+1:n)
162+
else
163+
R = view(A, j:k, 1:j-1)
164+
D = view(A, j:k, j:k)
165+
B = view(A, k+1:n, 1:j-1)
166+
C = view(A, k+1:n, j:k)
167+
end
168+
return R, D, B, C
169+
end
170+
171+
"""
172+
chol_unblocked_rev!(Ā::AbstractMatrix, L::AbstractMatrix, upper::Bool)
173+
174+
Compute the reverse-mode sensitivities of the Cholesky factorization in an unblocked manner.
175+
If `upper` is `false`, then the sensitivities are computed from and stored in the lower triangle
176+
of `Ā` and `L` respectively. If `upper` is `true` then they are computed and stored in the
177+
upper triangles. If at input `upper` is `false` and `tril(Ā) = L̄`, at output
178+
`tril(Ā) = tril(Σ̄)`, where `Σ = LLᵀ`. Analogously, if at input `upper` is `true` and
179+
`triu(Ā) = triu(Ū)`, at output `triu(Ā) = triu(Σ̄)` where `Σ = UᵀU`.
180+
"""
181+
function chol_unblocked_rev!(Σ̄::AbstractMatrix{T}, L::AbstractMatrix{T}, upper::Bool) where T<:Real
182+
n = checksquare(Σ̄)
183+
j = n
184+
@inbounds for _ in 1:n
185+
r, d, B, c = level2partition(L, j, upper)
186+
r̄, d̄, B̄, c̄ = level2partition(Σ̄, j, upper)
187+
188+
# d̄ <- d̄ - c'c̄ / d.
189+
d̄[1] -= dot(c, c̄) / d[1]
190+
191+
# [d̄ c̄'] <- [d̄ c̄'] / d.
192+
d̄ ./= d
193+
c̄ ./= d
194+
195+
# r̄ <- r̄ - [d̄ c̄'] [r' B']'.
196+
r̄ = axpy!(-Σ̄[j,j], r, r̄)
197+
r̄ = gemv!(upper ? 'n' : 'T', -one(T), B, c̄, one(T), r̄)
198+
199+
# B̄ <- B̄ - c̄ r.
200+
B̄ = upper ? ger!(-one(T), r, c̄, B̄) : ger!(-one(T), c̄, r, B̄)
201+
d̄ ./= 2
202+
j -= 1
203+
end
204+
return (upper ? triu! : tril!)(Σ̄)
205+
end
206+
207+
function chol_unblocked_rev(Σ̄::AbstractMatrix, L::AbstractMatrix, upper::Bool)
208+
return chol_unblocked_rev!(copy(Σ̄), L, upper)
209+
end
210+
211+
"""
212+
chol_blocked_rev!(Σ̄::AbstractMatrix, L::AbstractMatrix, nb::Integer, upper::Bool)
213+
214+
Compute the sensitivities of the Cholesky factorization using a blocked, cache-friendly
215+
procedure. `Σ̄` are the sensitivities of `L`, and will be transformed into the sensitivities
216+
of `Σ`, where `Σ = LLᵀ`. `nb` is the block size to use. If the upper triangle has been used
217+
to represent the factorization, that is `Σ = UᵀU` where `U := Lᵀ`, then this should be
218+
indicated by passing `upper = true`.
219+
"""
220+
function chol_blocked_rev!(Σ̄::AbstractMatrix{T}, L::AbstractMatrix{T}, nb::Integer, upper::Bool) where T<:Real
221+
n = checksquare(Σ̄)
222+
tmp = Matrix{T}(undef, nb, nb)
223+
k = n
224+
if upper
225+
@inbounds for _ in 1:nb:n
226+
j = max(1, k - nb + 1)
227+
R, D, B, C = level3partition(L, j, k, true)
228+
R̄, D̄, B̄, C̄ = level3partition(Σ̄, j, k, true)
229+
230+
C̄ = trsm!('L', 'U', 'N', 'N', one(T), D, C̄)
231+
gemm!('N', 'N', -one(T), R, C̄, one(T), B̄)
232+
gemm!('N', 'T', -one(T), C, C̄, one(T), D̄)
233+
chol_unblocked_rev!(D̄, D, true)
234+
gemm!('N', 'T', -one(T), B, C̄, one(T), R̄)
235+
if size(D̄, 1) == nb
236+
tmp = axpy!(one(T), D̄, transpose!(tmp, D̄))
237+
gemm!('N', 'N', -one(T), R, tmp, one(T), R̄)
238+
else
239+
gemm!('N', 'N', -one(T), R, D̄ + D̄', one(T), R̄)
240+
end
241+
242+
k -= nb
243+
end
244+
return triu!(Σ̄)
245+
else
246+
@inbounds for _ in 1:nb:n
247+
j = max(1, k - nb + 1)
248+
R, D, B, C = level3partition(L, j, k, false)
249+
R̄, D̄, B̄, C̄ = level3partition(Σ̄, j, k, false)
250+
251+
C̄ = trsm!('R', 'L', 'N', 'N', one(T), D, C̄)
252+
gemm!('N', 'N', -one(T), C̄, R, one(T), B̄)
253+
gemm!('T', 'N', -one(T), C̄, C, one(T), D̄)
254+
chol_unblocked_rev!(D̄, D, false)
255+
gemm!('T', 'N', -one(T), C̄, B, one(T), R̄)
256+
if size(D̄, 1) == nb
257+
tmp = axpy!(one(T), D̄, transpose!(tmp, D̄))
258+
gemm!('N', 'N', -one(T), tmp, R, one(T), R̄)
259+
else
260+
gemm!('N', 'N', -one(T), D̄ + D̄', R, one(T), R̄)
261+
end
262+
263+
k -= nb
264+
end
265+
return tril!(Σ̄)
266+
end
267+
end
268+
269+
function chol_blocked_rev(Σ̄::AbstractMatrix, L::AbstractMatrix, nb::Integer, upper::Bool)
270+
return chol_blocked_rev!(copy(Σ̄), L, nb, upper)
271+
end

test/rules/linalg/factorization.jl

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
using ChainRules: level2partition, level3partition, chol_blocked_rev, chol_unblocked_rev
2+
13
@testset "Factorizations" begin
24
@testset "svd" begin
35
rng = MersenneTwister(2)
@@ -35,4 +37,65 @@
3537
@test ChainRules._add!(copy(X), Y) ≈ X + Y
3638
end
3739
end
40+
@testset "cholesky" begin
41+
rng = MersenneTwister(4)
42+
@testset "the thing" begin
43+
X = generate_well_conditioned_matrix(rng, 10)
44+
V = generate_well_conditioned_matrix(rng, 10)
45+
F, dX = rrule(cholesky, X)
46+
for p in [:U, :L]
47+
Y, (dF, dp) = rrule(getproperty, F, p)
48+
@test dp isa ChainRules.DNERule
49+
Ȳ = (p === :U ? UpperTriangular : LowerTriangular)(randn(rng, size(Y)))
50+
# NOTE: We're doing Nabla-style testing here and avoiding using the `j′vp`
51+
# machinery from FDM because that isn't set up to respect necessary special
52+
# properties of the input. In the case of the Cholesky factorization, we
53+
# need the input to be Hermitian.
54+
X̄_ad = dot(dX(dF(Ȳ)), V)
55+
X̄_fd = central_fdm(5, 1)() do ε
56+
dot(Ȳ, getproperty(cholesky(X .+ ε .* V), p))
57+
end
58+
@test X̄_ad ≈ X̄_fd rtol=1e-6 atol=1e-6
59+
end
60+
end
61+
@testset "helper functions" begin
62+
A = randn(rng, 5, 5)
63+
r, d, B2, c = level2partition(A, 4, false)
64+
R, D, B3, C = level3partition(A, 4, 4, false)
65+
@test all(r .== R')
66+
@test all(d .== D)
67+
@test B2[1] == B3[1]
68+
@test all(c .== C)
69+
70+
# Check that level 2 partition with `upper == true` is consistent with `false`
71+
rᵀ, dᵀ, B2ᵀ, cᵀ = level2partition(transpose(A), 4, true)
72+
@test r == rᵀ
73+
@test d == dᵀ
74+
@test B2' == B2ᵀ
75+
@test c == cᵀ
76+
77+
# Check that level 3 partition with `upper == true` is consistent with `false`
78+
R, D, B3, C = level3partition(A, 2, 4, false)
79+
Rᵀ, Dᵀ, B3ᵀ, Cᵀ = level3partition(transpose(A), 2, 4, true)
80+
@test transpose(R) == Rᵀ
81+
@test transpose(D) == Dᵀ
82+
@test transpose(B3) == B3ᵀ
83+
@test transpose(C) == Cᵀ
84+
85+
A = Matrix(LowerTriangular(randn(rng, 10, 10)))
86+
Ā = Matrix(LowerTriangular(randn(rng, 10, 10)))
87+
# NOTE: BLAS gets angry if we don't materialize the Transpose objects first
88+
B = Matrix(transpose(A))
89+
B̄ = Matrix(transpose(Ā))
90+
@test chol_unblocked_rev(Ā, A, false) ≈ chol_blocked_rev(Ā, A, 1, false)
91+
@test chol_unblocked_rev(Ā, A, false) ≈ chol_blocked_rev(Ā, A, 3, false)
92+
@test chol_unblocked_rev(Ā, A, false) ≈ chol_blocked_rev(Ā, A, 5, false)
93+
@test chol_unblocked_rev(Ā, A, false) ≈ chol_blocked_rev(Ā, A, 10, false)
94+
@test chol_unblocked_rev(Ā, A, false) ≈ transpose(chol_unblocked_rev(B̄, B, true))
95+
96+
@test chol_unblocked_rev(B̄, B, true) ≈ chol_blocked_rev(B̄, B, 1, true)
97+
@test chol_unblocked_rev(B̄, B, true) ≈ chol_blocked_rev(B̄, B, 5, true)
98+
@test chol_unblocked_rev(B̄, B, true) ≈ chol_blocked_rev(B̄, B, 10, true)
99+
end
100+
end
38101
end

0 commit comments

Comments
 (0)