Skip to content

Commit be21539

Browse files
authored
Merge pull request #237 from ReactionMechanismGenerator/fix_adjoint_sens
Fix Adjoint Sensitivities
2 parents a3b22d5 + e80bb70 commit be21539

File tree

3 files changed

+24
-18
lines changed

3 files changed

+24
-18
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
99
Conda = "8f4d0f93-b110-5947-807f-2305c1781a2d"
1010
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
1111
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
12-
DiffEqSensitivity = "41bf760c-e81c-5289-8e54-58b1f1f8abe2"
1312
FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838"
1413
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1514
Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0"
@@ -28,6 +27,7 @@ QuartzImageIO = "dca85d43-d64c-5e67-8c65-017450d5d020"
2827
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
2928
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
3029
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
30+
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
3131
SmoothingSplines = "102930c3-cf33-599f-b3b1-9a29a5acab30"
3232
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
3333
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
@@ -47,7 +47,7 @@ Colors = "0.11,0.12"
4747
Conda = "1"
4848
CSV = "0"
4949
DataFrames = "1"
50-
DiffEqSensitivity = "6"
50+
SciMLSensitivity = "^7"
5151
ForwardDiff = "0.10"
5252
Images = "0.24"
5353
IncompleteLU = "0.2"

src/Simulation.jl

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
using SciMLBase
2-
import SciMLBase: AbstractODESolution, HermiteInterpolation
3-
using DiffEqSensitivity
2+
import SciMLBase: AbstractODESolution, HermiteInterpolation, AbstractDiffEqInterpolation
3+
using SciMLSensitivity
44
using ForwardDiff
55
using PreallocationTools
6+
using LinearAlgebra
67

78
abstract type AbstractSimulation end
89
export AbstractSimulation
@@ -459,8 +460,8 @@ By default uses the InterpolatingAdjoint algorithm with vector Jacobian products
459460
this assumes no changes in code branching during simulation, if that were to become no longer true, the Tracker
460461
based alternative algorithm is slower, but avoids this concern.
461462
"""
462-
function getadjointsensitivities(bsol::Q, target::String, solver::W; sensalg::W2=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(false)),
463-
abstol::Float64=1e-6, reltol::Float64=1e-3, normalize=true, kwargs...) where {Q,W,W2}
463+
function getadjointsensitivities(bsol::Simulation, target::String, solver; sensalg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(false)),
464+
abstol::Float64=1e-6, reltol::Float64=1e-3, normalize=true, kwargs...)
464465
@assert target in bsol.names || target in ["T", "V", "P", "mass"]
465466

466467
pethane = 160
@@ -538,15 +539,19 @@ function getadjointsensitivities(bsol::Q, target::String, solver::W; sensalg::W2
538539

539540
if length(bsol.domain.p) <= pethane
540541
if target in ["T", "V", "P", "mass"] || !isempty(bsol.interfaces)
541-
du0, dpadj = adjoint_sensitivities(bsol.sol, solver, g, nothing, (dgdu, dgdp); sensealg=sensalg, abstol=abstol, reltol=reltol, kwargs...)
542+
du0, dpadj = adjoint_sensitivities(bsol.sol, solver; g=g, dgdu_continuous=dgdu,
543+
dgdp_continuous=dgdp, sensealg=sensalg, abstol=abstol, reltol=reltol, kwargs...)
542544
else
543-
du0, dpadj = adjoint_sensitivities(bsol.sol, solver, sensg, nothing, (dsensgdu, dsensgdp); sensealg=sensalg, abstol=abstol, reltol=reltol, kwargs...)
545+
du0, dpadj = adjoint_sensitivities(bsol.sol, solver; g=sensg, dgdu_continuous=dsensgdu,
546+
dgdp_continuous=dsensgdp, sensealg=sensalg, abstol=abstol, reltol=reltol, kwargs...)
544547
end
545548
else
546549
if target in ["T", "V", "P", "mass"] || !isempty(bsol.interfaces)
547-
du0, dpadj = adjoint_sensitivities(bsol.sol, solver, g, nothing, (dgdurevdiff, dgdprevdiff); sensealg=sensalg, abstol=abstol, reltol=reltol, kwargs...)
550+
du0, dpadj = adjoint_sensitivities(bsol.sol, solver; g=g, dgdu_continuous=gdurevdiff,
551+
dgdp_continuous=dgdprevdiff, sensealg=sensalg, abstol=abstol, reltol=reltol, kwargs...)
548552
else
549-
du0, dpadj = adjoint_sensitivities(bsol.sol, solver, sensg, nothing, (dsensgdurevdiff, dsensgdprevdiff); sensealg=sensalg, abstol=abstol, reltol=reltol, kwargs...)
553+
du0, dpadj = adjoint_sensitivities(bsol.sol, solver; g=sensg, dgdu_continuous=dsensgdurevdiff,
554+
dgdp_continuous=dsensgdprevdiff, sensealg=sensalg, abstol=abstol, reltol=reltol, kwargs...)
550555
end
551556
end
552557
if normalize
@@ -557,7 +562,7 @@ function getadjointsensitivities(bsol::Q, target::String, solver::W; sensalg::W2
557562
dpadj ./= bsol.sol(bsol.sol.t[end])[ind]
558563
end
559564
end
560-
return dpadj
565+
return dpadj::LinearAlgebra.Adjoint{Float64, Vector{Float64}}
561566
end
562567

563568
function getadjointsensitivities(syssim::Q, bsol::W3, target::String, solver::W; sensalg::W2=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(false)),
@@ -591,7 +596,8 @@ function getadjointsensitivities(syssim::Q, bsol::W3, target::String, solver::W;
591596
end
592597
dgdu(out, y, p, t) = ForwardDiff.gradient!(out, y -> g(y, p, t), y)
593598
dgdp(out, y, p, t) = ForwardDiff.gradient!(out, p -> g(y, p, t), p)
594-
du0, dpadj = adjoint_sensitivities(syssim.sol, solver, g, nothing, (dgdu, dgdp); sensealg=sensalg, abstol=abstol, reltol=reltol, kwargs...)
599+
du0, dpadj = adjoint_sensitivities(syssim.sol, solver; g=g, dgdu_continuous=dgdu, dgdp_continuous=dgdp,
600+
sensealg=sensalg, abstol=abstol, reltol=reltol, kwargs...)
595601
if normalize
596602
for domain in domains
597603
dpadj[domain.parameterindexes[1]+length(domain.phase.species):domain.parameterindexes[2]] .*= syssim.p[domain.parameterindexes[1]+length(domain.phase.species):domain.parameterindexes[2]]
@@ -600,7 +606,7 @@ function getadjointsensitivities(syssim::Q, bsol::W3, target::String, solver::W;
600606
dpadj ./= bsol.sol(bsol.sol.t[end])[ind]
601607
end
602608
end
603-
return dpadj
609+
return dpadj::LinearAlgebra.Adjoint{Float64, Vector{Float64}}
604610
end
605611
export getadjointsensitivities
606612

src/TestReactors.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ jp=jacobianpforwarddiff(y,p,t,domain,[],nothing);
220220
@test all((abs.(jpa.-jp) .> 1e-4.*abs.(jp).+1e-16).==false)
221221

222222
#sensitivities
223-
dps = getadjointsensitivities(sim,"H2",CVODE_BDF(linear_solver=:GMRES);sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true)),abstol=1e-16,reltol=1e-6)
223+
dps = getadjointsensitivities(sim,"H2",CVODE_BDF(linear_solver=:GMRES);sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true)),abstol=1e-12,reltol=1e-6)
224224
react2 = Reactor(domain,y0,(0.0,150.11094);p=p,forwardsensitivities=true)
225225
sol2 = solve(react2.ode,CVODE_BDF(linear_solver=:GMRES),abstol=1e-21,reltol=1e-7); #solve the ode associated with the reactor
226226
sim2 = Simulation(sol2,domain)
@@ -272,7 +272,7 @@ end;
272272
@test all((abs.(jpa.-jp) .> 1e-4.*abs.(jp).+1e-16).==false)
273273

274274
#sensitivities
275-
dps = getadjointsensitivities(sim,"H2",CVODE_BDF(linear_solver=:GMRES);sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true)),abstol=1e-16,reltol=1e-6)
275+
dps = getadjointsensitivities(sim,"H2",CVODE_BDF(linear_solver=:GMRES);sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true)),abstol=1e-12,reltol=1e-6)
276276
react2 = Reactor(domain,y0,(0.0,150.11094),interfaces;p=p,forwardsensitivities=true)
277277
sol2 = solve(react2.ode,CVODE_BDF(linear_solver=:GMRES),abstol=1e-21,reltol=1e-7); #solve the ode associated with the reactor
278278
sim2 = Simulation(sol2,domain,interfaces)
@@ -316,7 +316,7 @@ jp=jacobianpforwarddiff(y,p,t,domain,[],nothing);
316316
react = Reactor(domain,y0,(0.0,0.02),p=p) #Create the reactor object
317317
sol = solve(react.ode,CVODE_BDF(),abstol=1e-20,reltol=1e-12); #solve the ode associated with the reactor
318318
sim = Simulation(sol,domain)
319-
dps = getadjointsensitivities(sim,"H2",CVODE_BDF();sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(false)),abstol=1e-16,reltol=1e-6)
319+
dps = getadjointsensitivities(sim,"H2",CVODE_BDF();sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(false)),abstol=1e-12,reltol=1e-6)
320320
react2 = Reactor(domain,y0,(0.0,0.02);p=p,forwardsensitivities=true)
321321
sol2 = solve(react2.ode,CVODE_BDF(),abstol=1e-16,reltol=1e-6); #solve the ode associated with the reactor
322322
sim2 = Simulation(sol2,domain)
@@ -488,8 +488,8 @@ end;
488488
@test sol(t)[1:length(spcs)] solV(t)[1:end-2] rtol=1e-5
489489
@test sol(t)[length(spcs)+1:end-4] solV(t)[1:end-2] rtol=1e-5
490490

491-
dpsV = getadjointsensitivities(simV,"H2",CVODE_BDF();sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(false)),abstol=1e-16,reltol=1e-6)
492-
dps = getadjointsensitivities(sysim,sysim.sims[1],"H2",CVODE_BDF();sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(false)),abstol=1e-16,reltol=1e-6)
491+
dpsV = getadjointsensitivities(simV,"H2",CVODE_BDF();sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(false)),abstol=1e-12,reltol=1e-6)
492+
dps = getadjointsensitivities(sysim,sysim.sims[1],"H2",CVODE_BDF();sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(false)),abstol=1e-12,reltol=1e-6)
493493
@test dpsV dps rtol=1e-4
494494
end;
495495

0 commit comments

Comments
 (0)