Skip to content

Commit 37311f4

Browse files
Merge pull request #2729 from SciML/nonlinearsolve_inference
Fix inference with NonlinearSolveAlg
2 parents 7b7dc6b + 7718e80 commit 37311f4

File tree

5 files changed

+47
-25
lines changed

5 files changed

+47
-25
lines changed

lib/OrdinaryDiffEqBDF/Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ DiffEqDevTools = "f3b72e0c-5b89-59e1-b016-84e28bfd966d"
6363
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
6464
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
6565
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
66+
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
6667
ODEProblemLibrary = "fdc4e326-1af4-4b90-96e7-779fcce2daa5"
6768
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
6869
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
@@ -72,4 +73,4 @@ JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
7273
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
7374

7475
[targets]
75-
test = ["DiffEqDevTools", "ForwardDiff", "Random", "SafeTestsets", "Test", "ODEProblemLibrary", "StaticArrays", "Enzyme", "LinearSolve", "JET", "Aqua"]
76+
test = ["DiffEqDevTools", "ForwardDiff", "Random", "SafeTestsets", "Test", "ODEProblemLibrary", "StaticArrays", "Enzyme", "LinearSolve", "JET", "Aqua", "NonlinearSolve"]

lib/OrdinaryDiffEqBDF/src/OrdinaryDiffEqBDF.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ import OrdinaryDiffEqCore
3232
using OrdinaryDiffEqDifferentiation: UJacobianWrapper
3333
using OrdinaryDiffEqNonlinearSolve: NLNewton, du_alias_or_new, build_nlsolver,
3434
nlsolve!, nlsolvefail, isnewton, markfirststage!,
35-
set_new_W!, DIRK, compute_step!, COEFFICIENT_MULTISTEP
35+
set_new_W!, DIRK, compute_step!, COEFFICIENT_MULTISTEP,
36+
NonlinearSolveAlg
3637
import ADTypes
3738
import ADTypes: AutoForwardDiff, AutoFiniteDiff, AbstractADType
3839

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
using OrdinaryDiffEqBDF, NonlinearSolve, Test
2+
3+
prob = ODEProblem((du,u,p,t) -> du .= u, zeros(1), (0.0,1.0))
4+
nlalg = FBDF(autodiff=false, nlsolve = OrdinaryDiffEqBDF.NonlinearSolveAlg(TrustRegion(autodiff = AutoFiniteDiff())))
5+
basicalg = FBDF(autodiff=false)
6+
basicalgad = FBDF()
7+
8+
nlsolver = @inferred OrdinaryDiffEqBDF.build_nlsolver(basicalg, prob.u0, prob.u0, prob.p, 0.0, 0.0, prob.f, prob.u0, Float64,
9+
Float64, Float64, 0.0, 0.0, Val(true))
10+
nlsolver = @inferred OrdinaryDiffEqBDF.build_nlsolver(nlalg, prob.u0, prob.u0, prob.p, 0.0, 0.0, prob.f, prob.u0, Float64,
11+
Float64, Float64, 0.0, 0.0, Val(true))
12+
nlsolver = @test_throws Any @inferred OrdinaryDiffEqBDF.build_nlsolver(basicalgad, prob.u0, prob.u0, prob.p, 0.0, 0.0, prob.f, prob.u0, Float64,
13+
Float64, Float64, 0.0, 0.0, Val(true))

lib/OrdinaryDiffEqBDF/test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using SafeTestsets
55
@time @safetestset "DAE Event Tests" include("dae_event.jl")
66
@time @safetestset "DAE Initialization Tests" include("dae_initialization_tests.jl")
77

8+
@time @safetestset "BDF Inference Tests" include("inference_tests.jl")
89
@time @safetestset "BDF Convergence Tests" include("bdf_convergence_tests.jl")
910
@time @safetestset "BDF Regression Tests" include("bdf_regression_tests.jl")
1011

lib/OrdinaryDiffEqNonlinearSolve/src/utils.jl

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,17 @@ function build_nlsolver(alg, u, uprev, p, t, dt, f::F, rate_prototype,
153153
uBottomEltypeNoUnits, tTypeNoUnits, γ, c, α, iip)
154154
end
155155

156+
function daenlf(ztmp, z, p)
157+
tmp, ustep, γ, α, tstep, k, invγdt, _p, dt, f = p
158+
_compute_rhs!(tmp, ztmp, ustep, γ, α, tstep, k, invγdt, _p, dt, f, z)[1]
159+
end
160+
161+
function odenlf(ztmp, z, p)
162+
tmp, ustep, γ, α, tstep, k, invγdt, method, _p, dt, f = p
163+
_compute_rhs!(
164+
tmp, ztmp, ustep, γ, α, tstep, k, invγdt, method, _p, dt, f, z)[1]
165+
end
166+
156167
function build_nlsolver(
157168
alg, nlalg::Union{NLFunctional, NLAnderson, NLNewton, NonlinearSolveAlg},
158169
u, uprev, p, t, dt,
@@ -215,19 +226,11 @@ function build_nlsolver(
215226
if nlalg isa NonlinearSolveAlg
216227
α = tTypeNoUnits(α)
217228
dt = tTypeNoUnits(dt)
218-
if isdae
219-
nlf = (ztmp, z, p) -> begin
220-
tmp, ustep, γ, α, tstep, k, invγdt, _p, dt, f = p
221-
_compute_rhs!(tmp, ztmp, ustep, γ, α, tstep, k, invγdt, _p, dt, f, z)[1]
222-
end
223-
nlp_params = (tmp, ustep, γ, α, tstep, k, invγdt, p, dt, f)
229+
nlf = isdae ? daenlf : odenlf
230+
nlp_params = if isdae
231+
(tmp, ustep, γ, α, tstep, k, invγdt, p, dt, f)
224232
else
225-
nlf = (ztmp, z, p) -> begin
226-
tmp, ustep, γ, α, tstep, k, invγdt, method, _p, dt, f = p
227-
_compute_rhs!(
228-
tmp, ztmp, ustep, γ, α, tstep, k, invγdt, method, _p, dt, f, z)[1]
229-
end
230-
nlp_params = (tmp, ustep, γ, α, tstep, k, invγdt, DIRK, p, dt, f)
233+
(tmp, ustep, γ, α, tstep, k, invγdt, DIRK, p, dt, f)
231234
end
232235
prob = NonlinearProblem(NonlinearFunction{true}(nlf), ztmp, nlp_params)
233236
cache = init(prob, nlalg.alg)
@@ -262,6 +265,16 @@ function build_nlsolver(
262265
Divergence, nlcache)
263266
end
264267

268+
function oopdaenlf(z, p)
269+
tmp, α, tstep, invγdt, _p, dt, uprev, f = p
270+
_compute_rhs(tmp, α, tstep, invγdt, p, dt, uprev, f, z)[1]
271+
end
272+
273+
function oopodenlf(z, p)
274+
tmp, γ, α, tstep, invγdt, method, _p, dt, f = p
275+
_compute_rhs(tmp, γ, α, tstep, invγdt, method, _p, dt, f, z)[1]
276+
end
277+
265278
function build_nlsolver(
266279
alg, nlalg::Union{NLFunctional, NLAnderson, NLNewton, NonlinearSolveAlg},
267280
u, uprev, p,
@@ -300,18 +313,11 @@ function build_nlsolver(
300313
if nlalg isa NonlinearSolveAlg
301314
α = tTypeNoUnits(α)
302315
dt = tTypeNoUnits(dt)
303-
if isdae
304-
nlf = (z, p) -> begin
305-
tmp, α, tstep, invγdt, _p, dt, uprev, f = p
306-
_compute_rhs(tmp, α, tstep, invγdt, p, dt, uprev, f, z)[1]
307-
end
308-
nlp_params = (tmp, α, tstep, invγdt, p, dt, uprev, f)
316+
nlf = isdae ? oopdaenlf : oopodenlf
317+
nlp_params = if isdae
318+
(tmp, α, tstep, invγdt, p, dt, uprev, f)
309319
else
310-
nlf = (z, p) -> begin
311-
tmp, γ, α, tstep, invγdt, method, _p, dt, f = p
312-
_compute_rhs(tmp, γ, α, tstep, invγdt, method, _p, dt, f, z)[1]
313-
end
314-
nlp_params = (tmp, γ, α, tstep, invγdt, DIRK, p, dt, f)
320+
(tmp, γ, α, tstep, invγdt, DIRK, p, dt, f)
315321
end
316322
prob = NonlinearProblem(NonlinearFunction{false}(nlf), copy(ztmp), nlp_params)
317323
cache = init(prob, nlalg.alg)

0 commit comments

Comments
 (0)