Skip to content

Commit a8a58ab

Browse files
authored
Merge pull request #2730 from oscardssmith/os/fix-FBDF-ad
fix autodiff for FBDF
2 parents 88c1407 + 7a09794 commit a8a58ab

File tree

3 files changed

+40
-16
lines changed

3 files changed

+40
-16
lines changed

lib/OrdinaryDiffEqBDF/src/bdf_caches.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,7 @@ function alg_cache(alg::QNDF{MO}, u, rate_prototype, ::Type{uEltypeNoUnits},
360360
::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits
361361
} where {MO}
362362
max_order = MO
363-
γ, c = one(uEltypeNoUnits), 1
363+
γ, c = 1//1, 1
364364
nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits,
365365
uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false))
366366
dtprev = one(dt)
@@ -426,7 +426,7 @@ function alg_cache(alg::QNDF{MO}, u, rate_prototype, ::Type{uEltypeNoUnits},
426426
::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits
427427
} where {MO}
428428
max_order = MO
429-
γ, c = one(eltype(alg.kappa)), 1
429+
γ, c = 1//1, 1
430430
nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits,
431431
uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true))
432432
fsalfirst = zero(rate_prototype)
@@ -541,7 +541,7 @@ function alg_cache(alg::FBDF{MO}, u, rate_prototype, ::Type{uEltypeNoUnits},
541541
dt, reltol, p, calck,
542542
::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits
543543
} where {MO}
544-
γ, c = one(uEltypeNoUnits), 1
544+
γ, c = 1//1, 1
545545
max_order = MO
546546
nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits,
547547
uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false))
@@ -616,7 +616,7 @@ function alg_cache(alg::FBDF{MO}, u, rate_prototype, ::Type{uEltypeNoUnits},
616616
dt, reltol, p, calck,
617617
::Val{true}) where {MO, uEltypeNoUnits, uBottomEltypeNoUnits,
618618
tTypeNoUnits}
619-
γ, c = one(uEltypeNoUnits), 1
619+
γ, c = 1//1, 1
620620
fsalfirst = zero(rate_prototype)
621621
max_order = MO
622622
nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits,

lib/OrdinaryDiffEqBDF/test/bdf_regression_tests.jl

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
using OrdinaryDiffEqBDF, Test
1+
using OrdinaryDiffEqBDF, ForwardDiff, Test
22

3-
foop = (u, p, t) -> u
4-
proboop = ODEProblem(foop, ones(2), (0.0, 1000.0))
3+
foop = (u, p, t) -> u * p
4+
proboop = ODEProblem(foop, ones(2), (0.0, 1000.0), 1.0)
55

6-
fiip = (du, u, p, t) -> du .= u
7-
probiip = ODEProblem(fiip, ones(2), (0.0, 1000.0))
6+
fiip = (du, u, p, t) -> du .= u .* p
7+
probiip = ODEProblem(fiip, ones(2), (0.0, 1000.0), 1.0)
88

99
@testset "FBDF reinit" begin
1010
for prob in [proboop, probiip]
@@ -18,3 +18,19 @@ probiip = ODEProblem(fiip, ones(2), (0.0, 1000.0))
1818
@test integ.sol.t[end] >= 700
1919
end
2020
end
21+
22+
function ad_helper(alg, prob)
23+
function costoop(p)
24+
_oprob = remake(prob; p)
25+
sol = solve(_oprob, alg, saveat=1:10)
26+
return sum(sol)
27+
end
28+
end
29+
30+
@testset "parameter autodiff" begin
31+
for prob in [proboop, probiip]
32+
for alg in [FBDF(), QNDF()]
33+
ForwardDiff.derivative(ad_helper(alg, prob), 1.0)
34+
end
35+
end
36+
end

lib/OrdinaryDiffEqBDF/test/dae_ad_tests.jl

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,16 @@ function f(du, u, p, t)
1414
+p[1] * u[1] - p[2] * u[2]^2 - p[3] * u[2] * u[3] - du[2],
1515
u[1] + u[2] + u[3] - 1.0]
1616
end
17+
function f_ode(du, u, p, t)
18+
du .= [-p[1] * u[1] + p[3] * u[2] * u[3],
19+
+p[1] * u[1] - p[2] * u[2]^2 - p[3] * u[2] * u[3],
20+
u[1] + u[2] + u[3] - 1.0]
21+
end
22+
function f_ode(u, p, t)
23+
[-p[1] * u[1] + p[3] * u[2] * u[3],
24+
+p[1] * u[1] - p[2] * u[2]^2 - p[3] * u[2] * u[3],
25+
u[1] + u[2] + u[3] - 1.0]
26+
end
1727
p = [0.04, 3e7, 1e4]
1828
u₀ = [1.0, 0, 0]
1929
du₀ = [-0.04, 0.04, 0.0]
@@ -22,16 +32,18 @@ differential_vars = [true, true, false]
2232
M = Diagonal([1.0, 1.0, 0.0])
2333
prob = DAEProblem(f, du₀, u₀, tspan, p, differential_vars = differential_vars)
2434
prob_oop = DAEProblem{false}(f, du₀, u₀, tspan, p, differential_vars = differential_vars)
25-
f_mm = ODEFunction{true, SciMLBase.AutoSpecialize}(f, mass_matrix = M)
35+
f_mm = ODEFunction{true}(f_ode, mass_matrix = M)
2636
prob_mm = ODEProblem(f_mm, u₀, tspan, p)
37+
f_mm_oop = ODEFunction{false}(f_ode, mass_matrix = M)
38+
prob_mm_oop = ODEProblem(f_mm_oop, u₀, tspan, p)
2739
@test_broken sol1 = @inferred solve(prob, DFBDF(autodiff=afd_cs3), dt = 1e-5, abstol = 1e-8, reltol = 1e-8)
2840
@test_broken sol2 = @inferred solve(prob_oop, DFBDF(autodiff=afd_cs3), dt = 1e-5, abstol = 1e-8, reltol = 1e-8)
2941
@test_broken sol3 = @inferred solve(prob_mm, FBDF(autodiff=afd_cs3), dt = 1e-5, abstol = 1e-8, reltol = 1e-8)
3042

3143
# These tests flex differentiation of the solver and through the initialization
3244
# To only test the solver part and isolate potential issues, set the initialization to consistent
3345
@testset "Inplace: $(isinplace(_prob)), DAEProblem: $(_prob isa DAEProblem), BrownBasic: $(initalg isa BrownFullBasicInit), Autodiff: $autodiff" for _prob in [
34-
prob, prob_oop, prob_mm],
46+
prob, prob_oop, prob_mm, prob_mm_oop],
3547
initalg in [BrownFullBasicInit(), ShampineCollocationInit()], autodiff in [afd_cs3, AutoFiniteDiff()]
3648

3749
alg = (_prob isa DAEProblem) ? DFBDF(; autodiff) : FBDF(; autodiff)
@@ -40,9 +52,5 @@ prob_mm = ODEProblem(f_mm, u₀, tspan, p)
4052
reltol = 1e-14, initializealg = initalg)
4153
sum(sol)
4254
end
43-
if _prob isa DAEProblem
44-
@test ForwardDiff.gradient(f, [0.04, 3e7, 1e4])[0, 0, 0] atol=1e-8
45-
else
46-
@test_broken ForwardDiff.gradient(f, [0.04, 3e7, 1e4])[0, 0, 0] atol=1e-8
47-
end
55+
@test ForwardDiff.gradient(f, [0.04, 3e7, 1e4])[0, 0, 0] atol=1e-8
4856
end

0 commit comments

Comments
 (0)