Skip to content

Commit 7a44f20

Browse files
authored
Merge pull request #772 from ElOceanografo/nogpl
Fix bugs with SparseInverseSubset on non-GPL Julia builds
2 parents c2fd16f + cfc7060 commit 7a44f20

File tree

3 files changed

+94
-88
lines changed

3 files changed

+94
-88
lines changed

src/rulesets/LinearAlgebra/structured.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
# Structured matrices
22
using LinearAlgebra: AbstractTriangular
3-
using SparseInverseSubset
43

54
# Matrix wrapper types that we know are square and are thus potentially invertible. For
65
# these we can use simpler definitions for `/` and `\`.

src/rulesets/SparseArrays/sparsematrix.jl

Lines changed: 84 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -50,94 +50,99 @@ function rrule(::typeof(findnz), v::AbstractSparseVector)
5050
return (I, V), findnz_pullback
5151
end
5252

53-
if VERSION < v"1.7"
54-
#=
55-
The method below for `logabsdet(F::UmfpackLU)` is required to calculate the (log)
56-
determinants of sparse matrices, but was not defined prior to Julia v1.7. In order
57-
for the rrules for the determinants of sparse matrices below to work, they need to be
58-
able to compute the primals as well, so this import from the future is included. For
59-
more recent versions of Julia, this definition lives in:
60-
julia/stdlib/SuiteSparse/src/umfpack.jl
61-
=#
62-
using SuiteSparse.UMFPACK: UmfpackLU
63-
64-
# compute the sign/parity of a permutation
65-
function _signperm(p)
66-
n = length(p)
67-
result = 0
68-
todo = trues(n)
69-
while any(todo)
70-
k = findfirst(todo)
71-
todo[k] = false
72-
result += 1 # increment element count
73-
j = p[k]
74-
while j != k
53+
if Base.USE_GPL_LIBS # Don't define rrules for sparse determinants if we don't have CHOLMOD from SuiteSparse.jl
54+
55+
if VERSION < v"1.7"
56+
#=
57+
The method below for `logabsdet(F::UmfpackLU)` is required to calculate the (log)
58+
determinants of sparse matrices, but was not defined prior to Julia v1.7. In order
59+
for the rrules for the determinants of sparse matrices below to work, they need to be
60+
able to compute the primals as well, so this import from the future is included. For
61+
more recent versions of Julia, this definition lives in:
62+
julia/stdlib/SuiteSparse/src/umfpack.jl
63+
=#
64+
using SuiteSparse.UMFPACK: UmfpackLU
65+
66+
# compute the sign/parity of a permutation
67+
function _signperm(p)
68+
n = length(p)
69+
result = 0
70+
todo = trues(n)
71+
while any(todo)
72+
k = findfirst(todo)
73+
todo[k] = false
7574
result += 1 # increment element count
76-
todo[j] = false
77-
j = p[j]
75+
j = p[k]
76+
while j != k
77+
result += 1 # increment element count
78+
todo[j] = false
79+
j = p[j]
80+
end
81+
result += 1 # increment cycle count
7882
end
79-
result += 1 # increment cycle count
83+
return ifelse(isodd(result), -1, 1)
8084
end
81-
return ifelse(isodd(result), -1, 1)
82-
end
8385

84-
function LinearAlgebra.logabsdet(F::UmfpackLU{T, TI}) where {T<:Union{Float64,ComplexF64},TI<:Union{Int32, Int64}}
85-
n = checksquare(F)
86-
issuccess(F) || return log(zero(real(T))), zero(T)
87-
U = F.U
88-
Rs = F.Rs
89-
p = F.p
90-
q = F.q
91-
s = _signperm(p)*_signperm(q)*one(real(T))
92-
P = one(T)
93-
abs_det = zero(real(T))
94-
@inbounds for i in 1:n
95-
dg_ii = U[i, i] / Rs[i]
96-
P *= sign(dg_ii)
97-
abs_det += log(abs(dg_ii))
86+
using SparseInverseSubset
87+
88+
function LinearAlgebra.logabsdet(F::UmfpackLU{T, TI}) where {T<:Union{Float64,ComplexF64},TI<:Union{Int32, Int64}}
89+
n = checksquare(F)
90+
issuccess(F) || return log(zero(real(T))), zero(T)
91+
U = F.U
92+
Rs = F.Rs
93+
p = F.p
94+
q = F.q
95+
s = _signperm(p)*_signperm(q)*one(real(T))
96+
P = one(T)
97+
abs_det = zero(real(T))
98+
@inbounds for i in 1:n
99+
dg_ii = U[i, i] / Rs[i]
100+
P *= sign(dg_ii)
101+
abs_det += log(abs(dg_ii))
102+
end
103+
return abs_det, s * P
98104
end
99-
return abs_det, s * P
100105
end
101-
end
102-
103-
104-
function rrule(::typeof(logabsdet), x::SparseMatrixCSC)
105-
F = cholesky(x)
106-
L, D, U, P = SparseInverseSubset.get_ldup(F)
107-
Ω = logabsdet(D)
108-
function logabsdet_pullback(ΔΩ)
109-
(Δy, Δsigny) = ΔΩ
110-
(_, signy) = Ω
111-
f = signy' * Δsigny
112-
imagf = f - real(f)
113-
g = real(Δy) + imagf
114-
Z, P = sparseinv(F, depermute=true)
115-
∂x = g * Z'
116-
return (NoTangent(), ∂x)
106+
107+
108+
function rrule(::typeof(logabsdet), x::SparseMatrixCSC)
109+
F = cholesky(x)
110+
L, D, U, P = SparseInverseSubset.get_ldup(F)
111+
Ω = logabsdet(D)
112+
function logabsdet_pullback(ΔΩ)
113+
(Δy, Δsigny) = ΔΩ
114+
(_, signy) = Ω
115+
f = signy' * Δsigny
116+
imagf = f - real(f)
117+
g = real(Δy) + imagf
118+
Z, P = sparseinv(F, depermute=true)
119+
∂x = g * Z'
120+
return (NoTangent(), ∂x)
121+
end
122+
return Ω, logabsdet_pullback
117123
end
118-
return Ω, logabsdet_pullback
119-
end
120-
121-
function rrule(::typeof(logdet), x::SparseMatrixCSC)
122-
Ω = logdet(x)
123-
function logdet_pullback(ΔΩ)
124-
Z, p = sparseinv(x, depermute=true)
125-
∂x = ΔΩ * Z'
126-
return (NoTangent(), ∂x)
124+
125+
function rrule(::typeof(logdet), x::SparseMatrixCSC)
126+
Ω = logdet(x)
127+
function logdet_pullback(ΔΩ)
128+
Z, p = sparseinv(x, depermute=true)
129+
∂x = ΔΩ * Z'
130+
return (NoTangent(), ∂x)
131+
end
132+
return Ω, logdet_pullback
127133
end
128-
return Ω, logdet_pullback
129-
end
130-
131-
function rrule(::typeof(det), x::SparseMatrixCSC)
132-
Ω = det(x)
133-
function det_pullback(ΔΩ)
134-
Z, _ = sparseinv(x, depermute=true)
135-
∂x = Z' * dot(Ω, ΔΩ)
136-
return (NoTangent(), ∂x)
134+
135+
function rrule(::typeof(det), x::SparseMatrixCSC)
136+
Ω = det(x)
137+
function det_pullback(ΔΩ)
138+
Z, _ = sparseinv(x, depermute=true)
139+
∂x = Z' * dot(Ω, ΔΩ)
140+
return (NoTangent(), ∂x)
141+
end
142+
return Ω, det_pullback
137143
end
138-
return Ω, det_pullback
139-
end
140-
144+
145+
end # rrules that depend on CHOLMOD
141146

142147
function rrule(::typeof(spdiagm), m::Integer, n::Integer, kv::Pair{<:Integer,<:AbstractVector}...)
143148

test/rulesets/SparseArrays/sparsematrix.jl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -79,12 +79,14 @@ end
7979
test_rrule(findnz, v dv, output_tangent=(zeros(length(I)), V̄))
8080
end
8181

82-
@testset "[log[abs[det]]] SparseMatrixCSC" begin
83-
ii = [1:5; 2; 4]
84-
jj = [1:5; 4; 2]
85-
x = [ones(5); 0.1; 0.1]
86-
A = sparse(ii, jj, x)
87-
test_rrule(logabsdet, A)
88-
test_rrule(logdet, A)
89-
test_rrule(det, A)
82+
if Base.USE_GPL_LIBS # these rrules don't work without CHOLMOD from SuiteSparse.jl
83+
@testset "[log[abs[det]]] SparseMatrixCSC" begin
84+
ii = [1:5; 2; 4]
85+
jj = [1:5; 4; 2]
86+
x = [ones(5); 0.1; 0.1]
87+
A = sparse(ii, jj, x)
88+
test_rrule(logabsdet, A)
89+
test_rrule(logdet, A)
90+
test_rrule(det, A)
91+
end
9092
end

0 commit comments

Comments
 (0)