Skip to content

Commit 869de2f

Browse files
Merge pull request #2662 from jClugstor/sparse_fixes
Fix sparsity pattern mismatch and test
2 parents 125184e + 4a873d6 commit 869de2f

File tree

4 files changed

+59
-4
lines changed

4 files changed

+59
-4
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9"
150150
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
151151
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
152152
DiffEqDevTools = "f3b72e0c-5b89-59e1-b016-84e28bfd966d"
153+
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
153154
ElasticArrays = "fdbdab4c-e67f-52f5-8c3f-e7b388dad3d4"
154155
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
155156
IncompleteLU = "40713840-3770-5561-ab4c-a76e7d0d7895"
@@ -175,4 +176,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
175176
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
176177

177178
[targets]
178-
test = ["Calculus", "ComponentArrays", "Symbolics", "AlgebraicMultigrid", "IncompleteLU", "DiffEqCallbacks", "DiffEqDevTools", "ODEProblemLibrary", "ElasticArrays", "InteractiveUtils", "ParameterizedFunctions", "PoissonRandom", "Printf", "Random", "ReverseDiff", "SafeTestsets", "SparseArrays", "Statistics", "StructArrays", "Test", "Unitful", "ModelingToolkit", "Pkg", "NLsolve", "RecursiveFactorization", "Enzyme", "SparseConnectivityTracer", "SparseMatrixColorings"]
179+
test = ["Calculus", "ComponentArrays", "Symbolics", "AlgebraicMultigrid", "IncompleteLU", "DiffEqCallbacks", "DifferentiationInterface", "DiffEqDevTools", "ODEProblemLibrary", "ElasticArrays", "InteractiveUtils", "ParameterizedFunctions", "PoissonRandom", "Printf", "Random", "ReverseDiff", "SafeTestsets", "SparseArrays", "Statistics", "StructArrays", "Test", "Unitful", "ModelingToolkit", "Pkg", "NLsolve", "RecursiveFactorization", "Enzyme", "SparseConnectivityTracer", "SparseMatrixColorings"]

lib/OrdinaryDiffEqDifferentiation/src/alg_utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ function prepare_user_sparsity(ad_alg, prob)
106106
sparsity = prob.f.sparsity
107107

108108
if !isnothing(sparsity) && !(ad_alg isa AutoSparse)
109-
if sparsity isa SparseMatrixCSC
109+
if sparsity isa SparseMatrixCSC && !SciMLBase.has_jac(prob.f)
110110
if prob.f.mass_matrix isa UniformScaling
111111
idxs = diagind(sparsity)
112112
@. @view(sparsity[idxs]) = 1

lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,21 @@ function calc_J!(J, integrator, cache, next_step::Bool = false)
172172
if DiffEqBase.has_jac(f)
173173
duprev = integrator.duprev
174174
uf = cache.uf
175-
f.jac(J, duprev, uprev, p, uf.α * uf.invγdt, t)
175+
# need to do some jank here to account for sparsity pattern of W
176+
# https://github.com/SciML/OrdinaryDiffEq.jl/issues/2653
177+
178+
# we need to set all nzval to a non-zero number
179+
# otherwise in the following line any zero gets interpreted as a structural zero
180+
if !isnothing(integrator.f.jac_prototype) &&
181+
integrator.f.jac_prototype isa SparseMatrixCSC
182+
183+
integrator.f.jac_prototype.nzval .= true
184+
J .= true .* integrator.f.jac_prototype
185+
J.nzval .= false
186+
f.jac(J, duprev, uprev, p, uf.α * uf.invγdt, t)
187+
else
188+
f.jac(J, duprev, uprev, p, uf.α * uf.invγdt, t)
189+
end
176190
else
177191
@unpack du1, uf, jac_config = cache
178192
# using `dz` as temporary array
@@ -183,7 +197,21 @@ function calc_J!(J, integrator, cache, next_step::Bool = false)
183197
end
184198
else
185199
if DiffEqBase.has_jac(f)
186-
f.jac(J, uprev, p, t)
200+
# need to do some jank here to account for sparsity pattern of W
201+
# https://github.com/SciML/OrdinaryDiffEq.jl/issues/2653
202+
203+
# we need to set all nzval to a non-zero number
204+
# otherwise in the following line any zero gets interpreted as a structural zero
205+
if !isnothing(integrator.f.jac_prototype) &&
206+
integrator.f.jac_prototype isa SparseMatrixCSC
207+
208+
integrator.f.jac_prototype.nzval .= true
209+
J .= true .* integrator.f.jac_prototype
210+
J.nzval .= false
211+
f.jac(J, uprev, p, t)
212+
else
213+
f.jac(J, uprev, p, t)
214+
end
187215
else
188216
@unpack du1, uf, jac_config = cache
189217
uf.f = nlsolve_f(f, alg)

test/interface/sparsediff_tests.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@ using OrdinaryDiffEq
33
using SparseArrays
44
using LinearAlgebra
55
using LinearSolve
6+
import DifferentiationInterface as DI
7+
using SparseConnectivityTracer
8+
using SparseMatrixColorings
69
using ADTypes
710
using Enzyme
811

@@ -84,3 +87,26 @@ for f in [f_oop, f_ip]
8487
end
8588
end
8689

90+
# test for https://github.com/SciML/OrdinaryDiffEq.jl/issues/2653#issuecomment-2778430025
91+
92+
using LinearAlgebra, SparseArrays
93+
using OrdinaryDiffEq
94+
95+
function f(du, u, p, t)
96+
du[1] = u[1]
97+
return du
98+
end
99+
100+
function jac(J::SparseMatrixCSC, u, p, t)
101+
@assert nnz(J) == 1 # mirrors the strict behavior of SparseMatrixColorings
102+
nonzeros(J)[1] = 1
103+
return J
104+
end
105+
106+
u0 = ones(10)
107+
jac_prototype = sparse(Diagonal(vcat(1, zeros(9))))
108+
109+
fun = ODEFunction(f; jac, jac_prototype)
110+
prob = ODEProblem(fun, u0, (0.0, 1.0))
111+
@test_nowarn sol = solve(prob, Rodas4(); reltol = 1e-8, abstol = 1e-8)
112+

0 commit comments

Comments
 (0)