@@ -50,94 +50,99 @@ function rrule(::typeof(findnz), v::AbstractSparseVector)
50
50
return (I, V), findnz_pullback
51
51
end
52
52
53
- if VERSION < v " 1.7"
54
- #=
55
- The method below for `logabsdet(F::UmfpackLU)` is required to calculate the (log)
56
- determinants of sparse matrices, but was not defined prior to Julia v1.7. In order
57
- for the rrules for the determinants of sparse matrices below to work, they need to be
58
- able to compute the primals as well, so this import from the future is included. For
59
- more recent versions of Julia, this definition lives in:
60
- julia/stdlib/SuiteSparse/src/umfpack.jl
61
- =#
62
- using SuiteSparse. UMFPACK: UmfpackLU
63
-
64
- # compute the sign/parity of a permutation
65
- function _signperm (p)
66
- n = length (p)
67
- result = 0
68
- todo = trues (n)
69
- while any (todo)
70
- k = findfirst (todo)
71
- todo[k] = false
72
- result += 1 # increment element count
73
- j = p[k]
74
- while j != k
53
+ if Base. USE_GPL_LIBS # Don't define rrules for sparse determinants if we don't have CHOLMOD from SuiteSparse.jl
54
+
55
+ if VERSION < v " 1.7"
56
+ #=
57
+ The method below for `logabsdet(F::UmfpackLU)` is required to calculate the (log)
58
+ determinants of sparse matrices, but was not defined prior to Julia v1.7. In order
59
+ for the rrules for the determinants of sparse matrices below to work, they need to be
60
+ able to compute the primals as well, so this import from the future is included. For
61
+ more recent versions of Julia, this definition lives in:
62
+ julia/stdlib/SuiteSparse/src/umfpack.jl
63
+ =#
64
+ using SuiteSparse. UMFPACK: UmfpackLU
65
+
66
+ # compute the sign/parity of a permutation
67
+ function _signperm (p)
68
+ n = length (p)
69
+ result = 0
70
+ todo = trues (n)
71
+ while any (todo)
72
+ k = findfirst (todo)
73
+ todo[k] = false
75
74
result += 1 # increment element count
76
- todo[j] = false
77
- j = p[j]
75
+ j = p[k]
76
+ while j != k
77
+ result += 1 # increment element count
78
+ todo[j] = false
79
+ j = p[j]
80
+ end
81
+ result += 1 # increment cycle count
78
82
end
79
- result += 1 # increment cycle count
83
+ return ifelse ( isodd ( result), - 1 , 1 )
80
84
end
81
- return ifelse (isodd (result), - 1 , 1 )
82
- end
83
85
84
- function LinearAlgebra. logabsdet (F:: UmfpackLU{T, TI} ) where {T<: Union{Float64,ComplexF64} ,TI<: Union{Int32, Int64} }
85
- n = checksquare (F)
86
- issuccess (F) || return log (zero (real (T))), zero (T)
87
- U = F. U
88
- Rs = F. Rs
89
- p = F. p
90
- q = F. q
91
- s = _signperm (p)* _signperm (q)* one (real (T))
92
- P = one (T)
93
- abs_det = zero (real (T))
94
- @inbounds for i in 1 : n
95
- dg_ii = U[i, i] / Rs[i]
96
- P *= sign (dg_ii)
97
- abs_det += log (abs (dg_ii))
86
+ using SparseInverseSubset
87
+
88
+ function LinearAlgebra. logabsdet (F:: UmfpackLU{T, TI} ) where {T<: Union{Float64,ComplexF64} ,TI<: Union{Int32, Int64} }
89
+ n = checksquare (F)
90
+ issuccess (F) || return log (zero (real (T))), zero (T)
91
+ U = F. U
92
+ Rs = F. Rs
93
+ p = F. p
94
+ q = F. q
95
+ s = _signperm (p)* _signperm (q)* one (real (T))
96
+ P = one (T)
97
+ abs_det = zero (real (T))
98
+ @inbounds for i in 1 : n
99
+ dg_ii = U[i, i] / Rs[i]
100
+ P *= sign (dg_ii)
101
+ abs_det += log (abs (dg_ii))
102
+ end
103
+ return abs_det, s * P
98
104
end
99
- return abs_det, s * P
100
105
end
101
- end
102
-
103
-
104
- function rrule (:: typeof (logabsdet), x:: SparseMatrixCSC )
105
- F = cholesky (x)
106
- L, D, U, P = SparseInverseSubset. get_ldup (F)
107
- Ω = logabsdet (D)
108
- function logabsdet_pullback (ΔΩ)
109
- (Δy, Δsigny) = ΔΩ
110
- (_, signy) = Ω
111
- f = signy' * Δsigny
112
- imagf = f - real (f)
113
- g = real (Δy) + imagf
114
- Z, P = sparseinv (F, depermute= true )
115
- ∂x = g * Z'
116
- return (NoTangent (), ∂x)
106
+
107
+
108
+ function rrule (:: typeof (logabsdet), x:: SparseMatrixCSC )
109
+ F = cholesky (x)
110
+ L, D, U, P = SparseInverseSubset. get_ldup (F)
111
+ Ω = logabsdet (D)
112
+ function logabsdet_pullback (ΔΩ)
113
+ (Δy, Δsigny) = ΔΩ
114
+ (_, signy) = Ω
115
+ f = signy' * Δsigny
116
+ imagf = f - real (f)
117
+ g = real (Δy) + imagf
118
+ Z, P = sparseinv (F, depermute= true )
119
+ ∂x = g * Z'
120
+ return (NoTangent (), ∂x)
121
+ end
122
+ return Ω, logabsdet_pullback
117
123
end
118
- return Ω, logabsdet_pullback
119
- end
120
-
121
- function rrule ( :: typeof (logdet), x :: SparseMatrixCSC )
122
- Ω = logdet (x )
123
- function logdet_pullback (ΔΩ)
124
- Z, p = sparseinv (x, depermute = true )
125
- ∂x = ΔΩ * Z '
126
- return ( NoTangent (), ∂x)
124
+
125
+ function rrule ( :: typeof (logdet), x :: SparseMatrixCSC )
126
+ Ω = logdet (x)
127
+ function logdet_pullback (ΔΩ )
128
+ Z, p = sparseinv (x, depermute = true )
129
+ ∂x = ΔΩ * Z '
130
+ return ( NoTangent (), ∂x )
131
+ end
132
+ return Ω, logdet_pullback
127
133
end
128
- return Ω, logdet_pullback
129
- end
130
-
131
- function rrule ( :: typeof (det), x :: SparseMatrixCSC )
132
- Ω = det (x )
133
- function det_pullback ( ΔΩ)
134
- Z, _ = sparseinv (x, depermute = true )
135
- ∂x = Z ' * dot (Ω, ΔΩ)
136
- return ( NoTangent (), ∂x)
134
+
135
+ function rrule ( :: typeof (det), x :: SparseMatrixCSC )
136
+ Ω = det (x)
137
+ function det_pullback (ΔΩ )
138
+ Z, _ = sparseinv (x, depermute = true )
139
+ ∂x = Z ' * dot (Ω, ΔΩ)
140
+ return ( NoTangent (), ∂x )
141
+ end
142
+ return Ω, det_pullback
137
143
end
138
- return Ω, det_pullback
139
- end
140
-
144
+
145
+ end # rrules that depend on CHOLMOD
141
146
142
147
function rrule (:: typeof (spdiagm), m:: Integer , n:: Integer , kv:: Pair{<:Integer,<:AbstractVector} ...)
143
148
0 commit comments