Skip to content

Commit 7df39de

Browse files
committed
import logabsdetogabsdet(F::UmfpackLU) from future
This method is required for the sparse logabsdet rrules, but was not included in Julia prior to v1.7.
1 parent 72e6a83 commit 7df39de

File tree

2 files changed

+35
-0
lines changed

2 files changed

+35
-0
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1616
SparseInverseSubset = "dc90abb0-5640-4711-901d-7e5b23a2fada"
1717
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1818
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
19+
SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
1920

2021
[compat]
2122
Adapt = "3.4.0"

src/rulesets/SparseArrays/sparsematrix.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,40 @@ function rrule(::typeof(findnz), v::AbstractSparseVector)
5050
return (I, V), findnz_pullback
5151
end
5252

53+
if VERSION < v"1.7"
54+
#=
55+
This method for `logabsdet(F::UmfpackLU)` is required to calculate the (log)determinants
56+
of sparse matrices, but was not defined prior to Julia v1.7. In order fo the rrules
57+
for the determinants of sparse matrices below to work, they need to be able to
58+
compute the primals as well, so this import from the future is included. For more
59+
recent versions of Julia, this definition lives in:
60+
julia/stdlib/SuiteSparse/src/umfpack.jl
61+
=#
62+
using SuiteSparse.UMFPACK: _signperm, UmfpackLU
63+
64+
for itype in (:Int32, :Int64)
65+
@eval begin
66+
function LinearAlgebra.logabsdet(F::UmfpackLU{T, $itype}) where {T<:Union{Float64,ComplexF64}}
67+
n = checksquare(F)
68+
issuccess(F) || return log(zero(real(T))), zero(T)
69+
U = F.U
70+
Rs = F.Rs
71+
p = F.p
72+
q = F.q
73+
s = _signperm(p)*_signperm(q)*one(real(T))
74+
P = one(T)
75+
abs_det = zero(real(T))
76+
@inbounds for i in 1:n
77+
dg_ii = U[i, i] / Rs[i]
78+
P *= sign(dg_ii)
79+
abs_det += log(abs(dg_ii))
80+
end
81+
return abs_det, s * P
82+
end
83+
end
84+
end
85+
end
86+
5387

5488
function rrule(::typeof(logabsdet), x::SparseMatrixCSC)
5589
F = cholesky(x)

0 commit comments

Comments
 (0)