Skip to content

Commit d5f015b

Browse files
committed
Move sparse logabsdet from structured to sparsematrix
1 parent 378aea2 commit d5f015b

File tree

4 files changed

+48
-47
lines changed

4 files changed

+48
-47
lines changed

src/rulesets/LinearAlgebra/structured.jl

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -268,40 +268,3 @@ function rrule(::typeof(logdet), X::Union{Diagonal, AbstractTriangular})
268268
end
269269
return y, logdet_pullback
270270
end
271-
272-
function rrule(::typeof(logabsdet), x::SparseMatrixCSC)
273-
F = cholesky(x)
274-
L, D, U, P = SparseInverseSubset.get_ldup(F)
275-
Ω = logabsdet(D)
276-
function logabsdet_pullback(ΔΩ)
277-
(Δy, Δsigny) = ΔΩ
278-
(_, signy) = Ω
279-
f = signy' * Δsigny
280-
imagf = f - real(f)
281-
g = real(Δy) + imagf
282-
Z, P = sparseinv(F, depermute=true)
283-
∂x = g * Z'
284-
return (NoTangent(), ∂x)
285-
end
286-
return Ω, logabsdet_pullback
287-
end
288-
289-
function rrule(::typeof(logdet), x::SparseMatrixCSC)
290-
Ω = logdet(x)
291-
function logdet_pullback(ΔΩ)
292-
Z, p = sparseinv(x, depermute=true)
293-
∂x = ΔΩ * Z'
294-
return (NoTangent(), ∂x)
295-
end
296-
return Ω, logdet_pullback
297-
end
298-
299-
function rrule(::typeof(det), x::SparseMatrixCSC)
300-
Ω = det(x)
301-
function det_pullback(ΔΩ)
302-
Z, _ = sparseinv(x, depermute=true)
303-
∂x = Z' * dot(Ω, ΔΩ)
304-
return (NoTangent(), ∂x)
305-
end
306-
return Ω, det_pullback
307-
end

src/rulesets/SparseArrays/sparsematrix.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,3 +49,41 @@ function rrule(::typeof(findnz), v::AbstractSparseVector)
4949

5050
return (I, V), findnz_pullback
5151
end
52+
53+
54+
function rrule(::typeof(logabsdet), x::SparseMatrixCSC)
55+
F = cholesky(x)
56+
L, D, U, P = SparseInverseSubset.get_ldup(F)
57+
Ω = logabsdet(D)
58+
function logabsdet_pullback(ΔΩ)
59+
(Δy, Δsigny) = ΔΩ
60+
(_, signy) = Ω
61+
f = signy' * Δsigny
62+
imagf = f - real(f)
63+
g = real(Δy) + imagf
64+
Z, P = sparseinv(F, depermute=true)
65+
∂x = g * Z'
66+
return (NoTangent(), ∂x)
67+
end
68+
return Ω, logabsdet_pullback
69+
end
70+
71+
function rrule(::typeof(logdet), x::SparseMatrixCSC)
72+
Ω = logdet(x)
73+
function logdet_pullback(ΔΩ)
74+
Z, p = sparseinv(x, depermute=true)
75+
∂x = ΔΩ * Z'
76+
return (NoTangent(), ∂x)
77+
end
78+
return Ω, logdet_pullback
79+
end
80+
81+
function rrule(::typeof(det), x::SparseMatrixCSC)
82+
Ω = det(x)
83+
function det_pullback(ΔΩ)
84+
Z, _ = sparseinv(x, depermute=true)
85+
∂x = Z' * dot(Ω, ΔΩ)
86+
return (NoTangent(), ∂x)
87+
end
88+
return Ω, det_pullback
89+
end

test/rulesets/LinearAlgebra/structured.jl

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -161,14 +161,4 @@
161161
end
162162
end
163163
end
164-
165-
@testset "[log[abs[det]]] SparseMatrixCSC" begin
166-
ii = 1:5
167-
jj = 1:5
168-
x = ones(5)
169-
A = sparse(ii, jj, x)
170-
test_rrule(logabsdet, A)
171-
test_rrule(logdet, A)
172-
test_rrule(det, A)
173-
end
174164
end

test/rulesets/SparseArrays/sparsematrix.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,13 @@ end
3333
= rand!(similar(V))
3434
test_rrule(findnz, v dv, output_tangent=(zeros(length(I)), V̄))
3535
end
36+
37+
@testset "[log[abs[det]]] SparseMatrixCSC" begin
38+
ii = 1:5
39+
jj = 1:5
40+
x = ones(5)
41+
A = sparse(ii, jj, x)
42+
test_rrule(logabsdet, A)
43+
test_rrule(logdet, A)
44+
test_rrule(det, A)
45+
end

0 commit comments

Comments
 (0)