Skip to content

Commit 7a9feab

Browse files
authored
Merge pull request #730 from ElOceanografo/sparsedet
det, logdet, and logabsdet rrules for SparseMatrixCSC
2 parents df672c3 + 85a83be commit 7a9feab

File tree

4 files changed

+101
-0
lines changed

4 files changed

+101
-0
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1313
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1414
RealDot = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9"
1515
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
16+
SparseInverseSubset = "dc90abb0-5640-4711-901d-7e5b23a2fada"
1617
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1718
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
19+
SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
1820

1921
[compat]
2022
Adapt = "3.4.0"

src/rulesets/LinearAlgebra/structured.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Structured matrices
22
using LinearAlgebra: AbstractTriangular
3+
using SparseInverseSubset
34

45
# Matrix wrapper types that we know are square and are thus potentially invertible. For
56
# these we can use simpler definitions for `/` and `\`.

src/rulesets/SparseArrays/sparsematrix.jl

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,3 +49,91 @@ function rrule(::typeof(findnz), v::AbstractSparseVector)
4949

5050
return (I, V), findnz_pullback
5151
end
52+
53+
if VERSION < v"1.7"
54+
#=
55+
The method below for `logabsdet(F::UmfpackLU)` is required to calculate the (log)
56+
determinants of sparse matrices, but was not defined prior to Julia v1.7. In order
57+
for the rrules for the determinants of sparse matrices below to work, they need to be
58+
able to compute the primals as well, so this import from the future is included. For
59+
more recent versions of Julia, this definition lives in:
60+
julia/stdlib/SuiteSparse/src/umfpack.jl
61+
=#
62+
using SuiteSparse.UMFPACK: UmfpackLU
63+
64+
# compute the sign/parity of a permutation
65+
function _signperm(p)
66+
n = length(p)
67+
result = 0
68+
todo = trues(n)
69+
while any(todo)
70+
k = findfirst(todo)
71+
todo[k] = false
72+
result += 1 # increment element count
73+
j = p[k]
74+
while j != k
75+
result += 1 # increment element count
76+
todo[j] = false
77+
j = p[j]
78+
end
79+
result += 1 # increment cycle count
80+
end
81+
return ifelse(isodd(result), -1, 1)
82+
end
83+
84+
function LinearAlgebra.logabsdet(F::UmfpackLU{T, TI}) where {T<:Union{Float64,ComplexF64},TI<:Union{Int32, Int64}}
85+
n = checksquare(F)
86+
issuccess(F) || return log(zero(real(T))), zero(T)
87+
U = F.U
88+
Rs = F.Rs
89+
p = F.p
90+
q = F.q
91+
s = _signperm(p)*_signperm(q)*one(real(T))
92+
P = one(T)
93+
abs_det = zero(real(T))
94+
@inbounds for i in 1:n
95+
dg_ii = U[i, i] / Rs[i]
96+
P *= sign(dg_ii)
97+
abs_det += log(abs(dg_ii))
98+
end
99+
return abs_det, s * P
100+
end
101+
end
102+
103+
104+
function rrule(::typeof(logabsdet), x::SparseMatrixCSC)
105+
F = cholesky(x)
106+
L, D, U, P = SparseInverseSubset.get_ldup(F)
107+
Ω = logabsdet(D)
108+
function logabsdet_pullback(ΔΩ)
109+
(Δy, Δsigny) = ΔΩ
110+
(_, signy) = Ω
111+
f = signy' * Δsigny
112+
imagf = f - real(f)
113+
g = real(Δy) + imagf
114+
Z, P = sparseinv(F, depermute=true)
115+
∂x = g * Z'
116+
return (NoTangent(), ∂x)
117+
end
118+
return Ω, logabsdet_pullback
119+
end
120+
121+
function rrule(::typeof(logdet), x::SparseMatrixCSC)
122+
Ω = logdet(x)
123+
function logdet_pullback(ΔΩ)
124+
Z, p = sparseinv(x, depermute=true)
125+
∂x = ΔΩ * Z'
126+
return (NoTangent(), ∂x)
127+
end
128+
return Ω, logdet_pullback
129+
end
130+
131+
function rrule(::typeof(det), x::SparseMatrixCSC)
132+
Ω = det(x)
133+
function det_pullback(ΔΩ)
134+
Z, _ = sparseinv(x, depermute=true)
135+
∂x = Z' * dot(Ω, ΔΩ)
136+
return (NoTangent(), ∂x)
137+
end
138+
return Ω, det_pullback
139+
end

test/rulesets/SparseArrays/sparsematrix.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,13 @@ end
3333
= rand!(similar(V))
3434
test_rrule(findnz, v dv, output_tangent=(zeros(length(I)), V̄))
3535
end
36+
37+
@testset "[log[abs[det]]] SparseMatrixCSC" begin
38+
ii = [1:5; 2; 4]
39+
jj = [1:5; 4; 2]
40+
x = [ones(5); 0.1; 0.1]
41+
A = sparse(ii, jj, x)
42+
test_rrule(logabsdet, A)
43+
test_rrule(logdet, A)
44+
test_rrule(det, A)
45+
end

0 commit comments

Comments
 (0)