|
| 1 | +using LinearAlgebra: checksquare |
| 2 | +using LinearAlgebra.BLAS: gemv, gemv!, gemm!, trsm!, axpy!, ger! |
| 3 | + |
1 | 4 | #####
|
2 | 5 | ##### `svd`
|
3 | 6 | #####
|
@@ -82,3 +85,187 @@ function _add!(X::AbstractMatrix{T}, Y::AbstractMatrix{T}) where T<:Real
|
82 | 85 | end
|
83 | 86 | X
|
84 | 87 | end
|
| 88 | + |
| 89 | +##### |
| 90 | +##### `cholesky` |
| 91 | +##### |
| 92 | + |
| 93 | +function rrule(::typeof(cholesky), X::AbstractMatrix{<:Real}) |
| 94 | + F = cholesky(X) |
| 95 | + ∂X = Rule(Ȳ->chol_blocked_rev(Matrix(Ȳ), Matrix(F.U), 25, true)) |
| 96 | + return F, ∂X |
| 97 | +end |
| 98 | + |
| 99 | +function rrule(::typeof(getproperty), F::Cholesky, x::Symbol) |
| 100 | + if x === :U |
| 101 | + if F.uplo === 'U' |
| 102 | + ∂F = Ȳ->UpperTriangular(Ȳ) |
| 103 | + else |
| 104 | + ∂F = Ȳ->LowerTriangular(Ȳ') |
| 105 | + end |
| 106 | + elseif x === :L |
| 107 | + if F.uplo === 'L' |
| 108 | + ∂F = Ȳ->LowerTriangular(Ȳ) |
| 109 | + else |
| 110 | + ∂F = Ȳ->UpperTriangular(Ȳ') |
| 111 | + end |
| 112 | + end |
| 113 | + return getproperty(F, x), (Rule(∂F), DNERule()) |
| 114 | +end |
| 115 | + |
| 116 | +# See "Differentiation of the Cholesky decomposition" (Murray 2016), pages 5-9 in particular, |
| 117 | +# for derivations. Here we're implementing the algorithms and their transposes. |
| 118 | + |
| 119 | +""" |
| 120 | + level2partition(A::AbstractMatrix, j::Integer, upper::Bool) |
| 121 | +
|
| 122 | +Returns views to various bits of the lower triangle of `A` according to the |
| 123 | +`level2partition` procedure defined in [1] if `upper` is `false`. If `upper` is `true` then |
| 124 | +the transposed views are returned from the upper triangle of `A`. |
| 125 | +
|
| 126 | +[1]: "Differentiation of the Cholesky decomposition", Murray 2016 |
| 127 | +""" |
| 128 | +function level2partition(A::AbstractMatrix, j::Integer, upper::Bool) |
| 129 | + n = checksquare(A) |
| 130 | + @boundscheck checkbounds(1:n, j) |
| 131 | + if upper |
| 132 | + r = view(A, 1:j-1, j) |
| 133 | + d = view(A, j, j) |
| 134 | + B = view(A, 1:j-1, j+1:n) |
| 135 | + c = view(A, j, j+1:n) |
| 136 | + else |
| 137 | + r = view(A, j, 1:j-1) |
| 138 | + d = view(A, j, j) |
| 139 | + B = view(A, j+1:n, 1:j-1) |
| 140 | + c = view(A, j+1:n, j) |
| 141 | + end |
| 142 | + return r, d, B, c |
| 143 | +end |
| 144 | + |
| 145 | +""" |
| 146 | + level3partition(A::AbstractMatrix, j::Integer, k::Integer, upper::Bool) |
| 147 | +
|
| 148 | +Returns views to various bits of the lower triangle of `A` according to the |
| 149 | +`level3partition` procedure defined in [1] if `upper` is `false`. If `upper` is `true` then |
| 150 | +the transposed views are returned from the upper triangle of `A`. |
| 151 | +
|
| 152 | +[1]: "Differentiation of the Cholesky decomposition", Murray 2016 |
| 153 | +""" |
| 154 | +function level3partition(A::AbstractMatrix, j::Integer, k::Integer, upper::Bool) |
| 155 | + n = checksquare(A) |
| 156 | + @boundscheck checkbounds(1:n, j) |
| 157 | + if upper |
| 158 | + R = view(A, 1:j-1, j:k) |
| 159 | + D = view(A, j:k, j:k) |
| 160 | + B = view(A, 1:j-1, k+1:n) |
| 161 | + C = view(A, j:k, k+1:n) |
| 162 | + else |
| 163 | + R = view(A, j:k, 1:j-1) |
| 164 | + D = view(A, j:k, j:k) |
| 165 | + B = view(A, k+1:n, 1:j-1) |
| 166 | + C = view(A, k+1:n, j:k) |
| 167 | + end |
| 168 | + return R, D, B, C |
| 169 | +end |
| 170 | + |
| 171 | +""" |
| 172 | + chol_unblocked_rev!(Ā::AbstractMatrix, L::AbstractMatrix, upper::Bool) |
| 173 | +
|
| 174 | +Compute the reverse-mode sensitivities of the Cholesky factorization in an unblocked manner. |
| 175 | +If `upper` is `false`, then the sensitivites are computed from and stored in the lower triangle |
| 176 | +of `Ā` and `L` respectively. If `upper` is `true` then they are computed and stored in the |
| 177 | +upper triangles. If at input `upper` is `false` and `tril(Ā) = L̄`, at output |
| 178 | +`tril(Ā) = tril(Σ̄)`, where `Σ = LLᵀ`. Analogously, if at input `upper` is `true` and |
| 179 | +`triu(Ā) = triu(Ū)`, at output `triu(Ā) = triu(Σ̄)` where `Σ = UᵀU`. |
| 180 | +""" |
| 181 | +function chol_unblocked_rev!(Σ̄::AbstractMatrix{T}, L::AbstractMatrix{T}, upper::Bool) where T<:Real |
| 182 | + n = checksquare(Σ̄) |
| 183 | + j = n |
| 184 | + @inbounds for _ in 1:n |
| 185 | + r, d, B, c = level2partition(L, j, upper) |
| 186 | + r̄, d̄, B̄, c̄ = level2partition(Σ̄, j, upper) |
| 187 | + |
| 188 | + # d̄ <- d̄ - c'c̄ / d. |
| 189 | + d̄[1] -= dot(c, c̄) / d[1] |
| 190 | + |
| 191 | + # [d̄ c̄'] <- [d̄ c̄'] / d. |
| 192 | + d̄ ./= d |
| 193 | + c̄ ./= d |
| 194 | + |
| 195 | + # r̄ <- r̄ - [d̄ c̄'] [r' B']'. |
| 196 | + r̄ = axpy!(-Σ̄[j,j], r, r̄) |
| 197 | + r̄ = gemv!(upper ? 'n' : 'T', -one(T), B, c̄, one(T), r̄) |
| 198 | + |
| 199 | + # B̄ <- B̄ - c̄ r. |
| 200 | + B̄ = upper ? ger!(-one(T), r, c̄, B̄) : ger!(-one(T), c̄, r, B̄) |
| 201 | + d̄ ./= 2 |
| 202 | + j -= 1 |
| 203 | + end |
| 204 | + return (upper ? triu! : tril!)(Σ̄) |
| 205 | +end |
| 206 | + |
| 207 | +function chol_unblocked_rev(Σ̄::AbstractMatrix, L::AbstractMatrix, upper::Bool) |
| 208 | + return chol_unblocked_rev!(copy(Σ̄), L, upper) |
| 209 | +end |
| 210 | + |
| 211 | +""" |
| 212 | + chol_blocked_rev!(Σ̄::AbstractMatrix, L::AbstractMatrix, nb::Integer, upper::Bool) |
| 213 | +
|
| 214 | +Compute the sensitivities of the Cholesky factorization using a blocked, cache-friendly |
| 215 | +procedure. `Σ̄` are the sensitivities of `L`, and will be transformed into the sensitivities |
| 216 | +of `Σ`, where `Σ = LLᵀ`. `nb` is the block size to use. If the upper triangle has been used |
| 217 | +to represent the factorization, that is `Σ = UᵀU` where `U := Lᵀ`, then this should be |
| 218 | +indicated by passing `upper = true`. |
| 219 | +""" |
| 220 | +function chol_blocked_rev!(Σ̄::AbstractMatrix{T}, L::AbstractMatrix{T}, nb::Integer, upper::Bool) where T<:Real |
| 221 | + n = checksquare(Σ̄) |
| 222 | + tmp = Matrix{T}(undef, nb, nb) |
| 223 | + k = n |
| 224 | + if upper |
| 225 | + @inbounds for _ in 1:nb:n |
| 226 | + j = max(1, k - nb + 1) |
| 227 | + R, D, B, C = level3partition(L, j, k, true) |
| 228 | + R̄, D̄, B̄, C̄ = level3partition(Σ̄, j, k, true) |
| 229 | + |
| 230 | + C̄ = trsm!('L', 'U', 'N', 'N', one(T), D, C̄) |
| 231 | + gemm!('N', 'N', -one(T), R, C̄, one(T), B̄) |
| 232 | + gemm!('N', 'T', -one(T), C, C̄, one(T), D̄) |
| 233 | + chol_unblocked_rev!(D̄, D, true) |
| 234 | + gemm!('N', 'T', -one(T), B, C̄, one(T), R̄) |
| 235 | + if size(D̄, 1) == nb |
| 236 | + tmp = axpy!(one(T), D̄, transpose!(tmp, D̄)) |
| 237 | + gemm!('N', 'N', -one(T), R, tmp, one(T), R̄) |
| 238 | + else |
| 239 | + gemm!('N', 'N', -one(T), R, D̄ + D̄', one(T), R̄) |
| 240 | + end |
| 241 | + |
| 242 | + k -= nb |
| 243 | + end |
| 244 | + return triu!(Σ̄) |
| 245 | + else |
| 246 | + @inbounds for _ in 1:nb:n |
| 247 | + j = max(1, k - nb + 1) |
| 248 | + R, D, B, C = level3partition(L, j, k, false) |
| 249 | + R̄, D̄, B̄, C̄ = level3partition(Σ̄, j, k, false) |
| 250 | + |
| 251 | + C̄ = trsm!('R', 'L', 'N', 'N', one(T), D, C̄) |
| 252 | + gemm!('N', 'N', -one(T), C̄, R, one(T), B̄) |
| 253 | + gemm!('T', 'N', -one(T), C̄, C, one(T), D̄) |
| 254 | + chol_unblocked_rev!(D̄, D, false) |
| 255 | + gemm!('T', 'N', -one(T), C̄, B, one(T), R̄) |
| 256 | + if size(D̄, 1) == nb |
| 257 | + tmp = axpy!(one(T), D̄, transpose!(tmp, D̄)) |
| 258 | + gemm!('N', 'N', -one(T), tmp, R, one(T), R̄) |
| 259 | + else |
| 260 | + gemm!('N', 'N', -one(T), D̄ + D̄', R, one(T), R̄) |
| 261 | + end |
| 262 | + |
| 263 | + k -= nb |
| 264 | + end |
| 265 | + return tril!(Σ̄) |
| 266 | + end |
| 267 | +end |
| 268 | + |
| 269 | +function chol_blocked_rev(Σ̄::AbstractMatrix, L::AbstractMatrix, nb::Integer, upper::Bool) |
| 270 | + return chol_blocked_rev!(copy(Σ̄), L, nb, upper) |
| 271 | +end |
0 commit comments