70
70
# #### `cholesky`
71
71
# ####
72
72
73
- function rrule (:: typeof (cholesky), X:: AbstractMatrix{<:Real} )
74
- F = cholesky (X)
75
- function cholesky_pullback (Ȳ:: Composite )
76
- ∂X = if F. uplo === ' U'
77
- chol_blocked_rev (Ȳ. U, F. U, 25 , true )
78
- else
79
- chol_blocked_rev (Ȳ. L, F. L, 25 , false )
80
- end
81
- return (NO_FIELDS, ∂X)
73
+ function rrule (:: typeof (cholesky), A:: Real , uplo:: Symbol = :U )
74
+ C = cholesky (A, uplo)
75
+ function cholesky_pullback (ΔC:: Composite )
76
+ return NO_FIELDS, ΔC. factors[1 , 1 ] / (2 * C. U[1 , 1 ]), DoesNotExist ()
82
77
end
83
- return F, cholesky_pullback
78
+ return C, cholesky_pullback
79
+ end
80
+
81
+ function rrule (:: typeof (cholesky), A:: Diagonal{<:Real} , :: Val{false} ; check:: Bool = true )
82
+ C = cholesky (A, Val (false ); check= check)
83
+ function cholesky_pullback (ΔC:: Composite )
84
+ Ā = Diagonal (diag (ΔC. factors) .* inv .(2 .* C. factors. diag))
85
+ return NO_FIELDS, Ā, DoesNotExist ()
86
+ end
87
+ return C, cholesky_pullback
88
+ end
89
+
90
+ # The appropriate cotangent is different depending upon whether A is Symmetric / Hermitian,
91
+ # or just a StridedMatrix.
92
+ # Implementation due to Seeger, Matthias, et al. "Auto-differentiating linear algebra."
93
+ function rrule (
94
+ :: typeof (cholesky),
95
+ A:: LinearAlgebra.HermOrSym{<:LinearAlgebra.BlasReal, <:StridedMatrix} ,
96
+ :: Val{false} ;
97
+ check:: Bool = true ,
98
+ )
99
+ C = cholesky (A, Val (false ); check= check)
100
+ function cholesky_pullback (ΔC:: Composite )
101
+ Ā, U = _cholesky_pullback_shared_code (C, ΔC)
102
+ Ā = BLAS. trsm! (' R' , ' U' , ' C' , ' N' , one (eltype (Ā)) / 2 , U. data, Ā)
103
+ return NO_FIELDS, _symhermtype (A)(Ā), DoesNotExist ()
104
+ end
105
+ return C, cholesky_pullback
106
+ end
107
+
108
+ function rrule (
109
+ :: typeof (cholesky),
110
+ A:: StridedMatrix{<:LinearAlgebra.BlasReal} ,
111
+ :: Val{false} ;
112
+ check:: Bool = true ,
113
+ )
114
+ C = cholesky (A, Val (false ); check= check)
115
+ function cholesky_pullback (ΔC:: Composite )
116
+ Ā, U = _cholesky_pullback_shared_code (C, ΔC)
117
+ Ā = BLAS. trsm! (' R' , ' U' , ' C' , ' N' , one (eltype (Ā)), U. data, Ā)
118
+ idx = diagind (Ā)
119
+ @views Ā[idx] .= real .(Ā[idx]) ./ 2
120
+ return (NO_FIELDS, UpperTriangular (Ā), DoesNotExist ())
121
+ end
122
+ return C, cholesky_pullback
123
+ end
124
+
125
+ function _cholesky_pullback_shared_code (C, ΔC)
126
+ U = C. U
127
+ Ū = ΔC. U
128
+ Ā = similar (U. data)
129
+ Ā = mul! (Ā, Ū, U' )
130
+ Ā = LinearAlgebra. copytri! (Ā, ' U' , true )
131
+ Ā = ldiv! (U, Ā)
132
+ return Ā, U
84
133
end
85
134
86
- function rrule (:: typeof (getproperty), F:: T , x:: Symbol ) where T <: Cholesky
135
+ function rrule (:: typeof (getproperty), F:: T , x:: Symbol ) where { T <: Cholesky }
87
136
function getproperty_cholesky_pullback (Ȳ)
88
137
C = Composite{T}
89
138
∂F = if x === :U
@@ -103,161 +152,3 @@ function rrule(::typeof(getproperty), F::T, x::Symbol) where T <: Cholesky
103
152
end
104
153
return getproperty (F, x), getproperty_cholesky_pullback
105
154
end
106
-
107
- # See "Differentiation of the Cholesky decomposition" (Murray 2016), pages 5-9 in particular,
108
- # for derivations. Here we're implementing the algorithms and their transposes.
109
-
110
- """
111
- level2partition(A::AbstractMatrix, j::Integer, upper::Bool)
112
-
113
- Returns views to various bits of the lower triangle of `A` according to the
114
- `level2partition` procedure defined in [1] if `upper` is `false`. If `upper` is `true` then
115
- the transposed views are returned from the upper triangle of `A`.
116
-
117
- [1]: "Differentiation of the Cholesky decomposition", Murray 2016
118
- """
119
- function level2partition (A:: AbstractMatrix , j:: Integer , upper:: Bool )
120
- n = checksquare (A)
121
- @boundscheck checkbounds (1 : n, j)
122
- if upper
123
- r = view (A, 1 : j- 1 , j)
124
- d = view (A, j, j)
125
- B = view (A, 1 : j- 1 , j+ 1 : n)
126
- c = view (A, j, j+ 1 : n)
127
- else
128
- r = view (A, j, 1 : j- 1 )
129
- d = view (A, j, j)
130
- B = view (A, j+ 1 : n, 1 : j- 1 )
131
- c = view (A, j+ 1 : n, j)
132
- end
133
- return r, d, B, c
134
- end
135
-
136
- """
137
- level3partition(A::AbstractMatrix, j::Integer, k::Integer, upper::Bool)
138
-
139
- Returns views to various bits of the lower triangle of `A` according to the
140
- `level3partition` procedure defined in [1] if `upper` is `false`. If `upper` is `true` then
141
- the transposed views are returned from the upper triangle of `A`.
142
-
143
- [1]: "Differentiation of the Cholesky decomposition", Murray 2016
144
- """
145
- function level3partition (A:: AbstractMatrix , j:: Integer , k:: Integer , upper:: Bool )
146
- n = checksquare (A)
147
- @boundscheck checkbounds (1 : n, j)
148
- if upper
149
- R = view (A, 1 : j- 1 , j: k)
150
- D = view (A, j: k, j: k)
151
- B = view (A, 1 : j- 1 , k+ 1 : n)
152
- C = view (A, j: k, k+ 1 : n)
153
- else
154
- R = view (A, j: k, 1 : j- 1 )
155
- D = view (A, j: k, j: k)
156
- B = view (A, k+ 1 : n, 1 : j- 1 )
157
- C = view (A, k+ 1 : n, j: k)
158
- end
159
- return R, D, B, C
160
- end
161
-
162
- """
163
- chol_unblocked_rev!(Ā::AbstractMatrix, L::AbstractMatrix, upper::Bool)
164
-
165
- Compute the reverse-mode sensitivities of the Cholesky factorization in an unblocked manner.
166
- If `upper` is `false`, then the sensitivites are computed from and stored in the lower triangle
167
- of `Ā` and `L` respectively. If `upper` is `true` then they are computed and stored in the
168
- upper triangles. If at input `upper` is `false` and `tril(Ā) = L̄`, at output
169
- `tril(Ā) = tril(Σ̄)`, where `Σ = LLᵀ`. Analogously, if at input `upper` is `true` and
170
- `triu(Ā) = triu(Ū)`, at output `triu(Ā) = triu(Σ̄)` where `Σ = UᵀU`.
171
- """
172
- function chol_unblocked_rev! (Σ̄:: AbstractMatrix{T} , L:: AbstractMatrix{T} , upper:: Bool ) where T<: Real
173
- n = checksquare (Σ̄)
174
- j = n
175
- @inbounds for _ in 1 : n
176
- r, d, B, c = level2partition (L, j, upper)
177
- r̄, d̄, B̄, c̄ = level2partition (Σ̄, j, upper)
178
-
179
- # d̄ <- d̄ - c'c̄ / d.
180
- d̄[1 ] -= dot (c, c̄) / d[1 ]
181
-
182
- # [d̄ c̄'] <- [d̄ c̄'] / d.
183
- d̄ ./= d
184
- c̄ ./= d
185
-
186
- # r̄ <- r̄ - [d̄ c̄'] [r' B']'.
187
- r̄ = axpy! (- Σ̄[j,j], r, r̄)
188
- r̄ = gemv! (upper ? ' n' : ' T' , - one (T), B, c̄, one (T), r̄)
189
-
190
- # B̄ <- B̄ - c̄ r.
191
- B̄ = upper ? ger! (- one (T), r, c̄, B̄) : ger! (- one (T), c̄, r, B̄)
192
- d̄ ./= 2
193
- j -= 1
194
- end
195
- return (upper ? triu! : tril!)(Σ̄)
196
- end
197
-
198
- function chol_unblocked_rev (Σ̄:: AbstractMatrix , L:: AbstractMatrix , upper:: Bool )
199
- return chol_unblocked_rev! (copy (Σ̄), L, upper)
200
- end
201
-
202
- """
203
- chol_blocked_rev!(Σ̄::StridedMatrix, L::StridedMatrix, nb::Integer, upper::Bool)
204
-
205
- Compute the sensitivities of the Cholesky factorization using a blocked, cache-friendly
206
- procedure. `Σ̄` are the sensitivities of `L`, and will be transformed into the sensitivities
207
- of `Σ`, where `Σ = LLᵀ`. `nb` is the block size to use. If the upper triangle has been used
208
- to represent the factorization, that is `Σ = UᵀU` where `U := Lᵀ`, then this should be
209
- indicated by passing `upper = true`.
210
- """
211
- function chol_blocked_rev! (Σ̄:: StridedMatrix{T} , L:: StridedMatrix{T} , nb:: Integer , upper:: Bool ) where T<: Real
212
- n = checksquare (Σ̄)
213
- tmp = Matrix {T} (undef, nb, nb)
214
- k = n
215
- if upper
216
- @inbounds for _ in 1 : nb: n
217
- j = max (1 , k - nb + 1 )
218
- R, D, B, C = level3partition (L, j, k, true )
219
- R̄, D̄, B̄, C̄ = level3partition (Σ̄, j, k, true )
220
-
221
- C̄ = trsm! (' L' , ' U' , ' N' , ' N' , one (T), D, C̄)
222
- gemm! (' N' , ' N' , - one (T), R, C̄, one (T), B̄)
223
- gemm! (' N' , ' T' , - one (T), C, C̄, one (T), D̄)
224
- chol_unblocked_rev! (D̄, D, true )
225
- gemm! (' N' , ' T' , - one (T), B, C̄, one (T), R̄)
226
- if size (D̄, 1 ) == nb
227
- tmp = axpy! (one (T), D̄, transpose! (tmp, D̄))
228
- gemm! (' N' , ' N' , - one (T), R, tmp, one (T), R̄)
229
- else
230
- gemm! (' N' , ' N' , - one (T), R, D̄ + D̄' , one (T), R̄)
231
- end
232
-
233
- k -= nb
234
- end
235
- return triu! (Σ̄)
236
- else
237
- @inbounds for _ in 1 : nb: n
238
- j = max (1 , k - nb + 1 )
239
- R, D, B, C = level3partition (L, j, k, false )
240
- R̄, D̄, B̄, C̄ = level3partition (Σ̄, j, k, false )
241
-
242
- C̄ = trsm! (' R' , ' L' , ' N' , ' N' , one (T), D, C̄)
243
- gemm! (' N' , ' N' , - one (T), C̄, R, one (T), B̄)
244
- gemm! (' T' , ' N' , - one (T), C̄, C, one (T), D̄)
245
- chol_unblocked_rev! (D̄, D, false )
246
- gemm! (' T' , ' N' , - one (T), C̄, B, one (T), R̄)
247
- if size (D̄, 1 ) == nb
248
- tmp = axpy! (one (T), D̄, transpose! (tmp, D̄))
249
- gemm! (' N' , ' N' , - one (T), tmp, R, one (T), R̄)
250
- else
251
- gemm! (' N' , ' N' , - one (T), D̄ + D̄' , R, one (T), R̄)
252
- end
253
-
254
- k -= nb
255
- end
256
- return tril! (Σ̄)
257
- end
258
- end
259
-
260
- function chol_blocked_rev (Σ̄:: AbstractMatrix , L:: AbstractMatrix , nb:: Integer , upper:: Bool )
261
- # Convert to `Matrix`s because blas functions require StridedMatrix input.
262
- return chol_blocked_rev! (Matrix (Σ̄), Matrix (L), nb, upper)
263
- end
0 commit comments