@@ -49,3 +49,91 @@ function rrule(::typeof(findnz), v::AbstractSparseVector)
49
49
50
50
return (I, V), findnz_pullback
51
51
end
52
+
53
+ if VERSION < v " 1.7"
54
+ #=
55
+ The method below for `logabsdet(F::UmfpackLU)` is required to calculate the (log)
56
+ determinants of sparse matrices, but was not defined prior to Julia v1.7. In order
57
+ for the rrules for the determinants of sparse matrices below to work, they need to be
58
+ able to compute the primals as well, so this import from the future is included. For
59
+ more recent versions of Julia, this definition lives in:
60
+ julia/stdlib/SuiteSparse/src/umfpack.jl
61
+ =#
62
+ using SuiteSparse. UMFPACK: UmfpackLU
63
+
64
+ # compute the sign/parity of a permutation
65
+ function _signperm (p)
66
+ n = length (p)
67
+ result = 0
68
+ todo = trues (n)
69
+ while any (todo)
70
+ k = findfirst (todo)
71
+ todo[k] = false
72
+ result += 1 # increment element count
73
+ j = p[k]
74
+ while j != k
75
+ result += 1 # increment element count
76
+ todo[j] = false
77
+ j = p[j]
78
+ end
79
+ result += 1 # increment cycle count
80
+ end
81
+ return ifelse (isodd (result), - 1 , 1 )
82
+ end
83
+
84
+ function LinearAlgebra. logabsdet (F:: UmfpackLU{T, TI} ) where {T<: Union{Float64,ComplexF64} ,TI<: Union{Int32, Int64} }
85
+ n = checksquare (F)
86
+ issuccess (F) || return log (zero (real (T))), zero (T)
87
+ U = F. U
88
+ Rs = F. Rs
89
+ p = F. p
90
+ q = F. q
91
+ s = _signperm (p)* _signperm (q)* one (real (T))
92
+ P = one (T)
93
+ abs_det = zero (real (T))
94
+ @inbounds for i in 1 : n
95
+ dg_ii = U[i, i] / Rs[i]
96
+ P *= sign (dg_ii)
97
+ abs_det += log (abs (dg_ii))
98
+ end
99
+ return abs_det, s * P
100
+ end
101
+ end
102
+
103
+
104
+ function rrule (:: typeof (logabsdet), x:: SparseMatrixCSC )
105
+ F = cholesky (x)
106
+ L, D, U, P = SparseInverseSubset. get_ldup (F)
107
+ Ω = logabsdet (D)
108
+ function logabsdet_pullback (ΔΩ)
109
+ (Δy, Δsigny) = ΔΩ
110
+ (_, signy) = Ω
111
+ f = signy' * Δsigny
112
+ imagf = f - real (f)
113
+ g = real (Δy) + imagf
114
+ Z, P = sparseinv (F, depermute= true )
115
+ ∂x = g * Z'
116
+ return (NoTangent (), ∂x)
117
+ end
118
+ return Ω, logabsdet_pullback
119
+ end
120
+
121
+ function rrule (:: typeof (logdet), x:: SparseMatrixCSC )
122
+ Ω = logdet (x)
123
+ function logdet_pullback (ΔΩ)
124
+ Z, p = sparseinv (x, depermute= true )
125
+ ∂x = ΔΩ * Z'
126
+ return (NoTangent (), ∂x)
127
+ end
128
+ return Ω, logdet_pullback
129
+ end
130
+
131
+ function rrule (:: typeof (det), x:: SparseMatrixCSC )
132
+ Ω = det (x)
133
+ function det_pullback (ΔΩ)
134
+ Z, _ = sparseinv (x, depermute= true )
135
+ ∂x = Z' * dot (Ω, ΔΩ)
136
+ return (NoTangent (), ∂x)
137
+ end
138
+ return Ω, det_pullback
139
+ end
0 commit comments