Skip to content

Commit 3577abd

Browse files
committed
det, logdet, logabsdet rrules for SparseMatrixCSC
1 parent 11c230c commit 3577abd

File tree

3 files changed

+49
-0
lines changed

3 files changed

+49
-0
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1313
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1414
RealDot = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9"
1515
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
16+
SparseInverseSubset = "dc90abb0-5640-4711-901d-7e5b23a2fada"
1617
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1718
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
1819

src/rulesets/LinearAlgebra/structured.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Structured matrices
22
using LinearAlgebra: AbstractTriangular
3+
using SparseInverseSubset
34

45
# Matrix wrapper types that we know are square and are thus potentially invertible. For
56
# these we can use simpler definitions for `/` and `\`.
@@ -267,3 +268,40 @@ function rrule(::typeof(logdet), X::Union{Diagonal, AbstractTriangular})
267268
end
268269
return y, logdet_pullback
269270
end
271+
272+
function rrule(::typeof(logabsdet), x::SparseMatrixCSC)
273+
F = cholesky(x)
274+
L, D, U, P = SparseInverseSubset.get_ldup(F)
275+
Ω = logabsdet(D)
276+
function logabsdet_pullback(ΔΩ)
277+
(Δy, Δsigny) = ΔΩ
278+
(_, signy) = Ω
279+
f = signy' * Δsigny
280+
imagf = f - real(f)
281+
g = real(Δy) + imagf
282+
Z, P = sparseinv(F, depermute=true)
283+
∂x = g * Z'
284+
return (NoTangent(), ∂x)
285+
end
286+
return Ω, logabsdet_pullback
287+
end
288+
289+
function rrule(::typeof(logdet), x::SparseMatrixCSC)
290+
Ω = logdet(x)
291+
function logdet_pullback(ΔΩ)
292+
Z, p = sparseinv(x, depermute=true)
293+
∂x = x isa Number ? ΔΩ / x' : ΔΩ * Z'
294+
return (NoTangent(), ∂x)
295+
end
296+
return Ω, logdet_pullback
297+
end
298+
299+
function rrule(::typeof(det), x::SparseMatrixCSC)
300+
Ω = det(x)
301+
function det_pullback(ΔΩ)
302+
Z, _ = sparseinv(x, depermute=true)
303+
∂x = x isa Number ? ΔΩ : Z' * dot(Ω, ΔΩ)
304+
return (NoTangent(), ∂x)
305+
end
306+
return Ω, det_pullback
307+
end

test/rulesets/LinearAlgebra/structured.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,4 +161,14 @@
161161
end
162162
end
163163
end
164+
165+
@testset "[log[abs[det]]] SparseMatrixCSC" begin
166+
ii = 1:5
167+
jj = 1:5
168+
x = ones(5)
169+
A = sparse(ii, jj, x)
170+
test_rrule(logabsdet, A)
171+
test_rrule(logdet, A)
172+
test_rrule(det, A)
173+
end
164174
end

0 commit comments

Comments
 (0)