Skip to content

Commit 97cbe16

Browse files
committed
adapt adjoint_sensitivity syntax
1 parent 29f6718 commit 97cbe16

File tree

1 file changed

+12
-7
lines changed

1 file changed

+12
-7
lines changed

src/Simulation.jl

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using SciMLBase
2-
import SciMLBase: AbstractODESolution, HermiteInterpolation
3-
using DiffEqSensitivity
2+
import SciMLBase: AbstractODESolution, HermiteInterpolation, AbstractDiffEqInterpolation
3+
using SciMLSensitivity
44
using ForwardDiff
55
using PreallocationTools
66

@@ -538,15 +538,19 @@ function getadjointsensitivities(bsol::Q, target::String, solver::W; sensalg::W2
538538

539539
if length(bsol.domain.p) <= pethane
540540
if target in ["T", "V", "P", "mass"] || !isempty(bsol.interfaces)
541-
du0, dpadj = adjoint_sensitivities(bsol.sol, solver, g, nothing, (dgdu, dgdp); sensealg=sensalg, abstol=abstol, reltol=reltol, kwargs...)
541+
du0, dpadj = adjoint_sensitivities(bsol.sol, solver; g=g, dgdu_continuous=dgdu,
542+
dgdp_continuous=dgdp, sensealg=sensalg, abstol=abstol, reltol=reltol, kwargs...)
542543
else
543-
du0, dpadj = adjoint_sensitivities(bsol.sol, solver, sensg, nothing, (dsensgdu, dsensgdp); sensealg=sensalg, abstol=abstol, reltol=reltol, kwargs...)
544+
du0, dpadj = adjoint_sensitivities(bsol.sol, solver; g=sensg, dgdu_continuous=dsensgdu,
545+
dgdp_continuous=dsensgdp, sensealg=sensalg, abstol=abstol, reltol=reltol, kwargs...)
544546
end
545547
else
546548
if target in ["T", "V", "P", "mass"] || !isempty(bsol.interfaces)
547-
du0, dpadj = adjoint_sensitivities(bsol.sol, solver, g, nothing, (dgdurevdiff, dgdprevdiff); sensealg=sensalg, abstol=abstol, reltol=reltol, kwargs...)
549+
du0, dpadj = adjoint_sensitivities(bsol.sol, solver; g=g, dgdu_continuous=gdurevdiff,
550+
dgdp_continuous=dgdprevdiff, sensealg=sensalg, abstol=abstol, reltol=reltol, kwargs...)
548551
else
549-
du0, dpadj = adjoint_sensitivities(bsol.sol, solver, sensg, nothing, (dsensgdurevdiff, dsensgdprevdiff); sensealg=sensalg, abstol=abstol, reltol=reltol, kwargs...)
552+
du0, dpadj = adjoint_sensitivities(bsol.sol, solver; g=sensg, dgdu_continuous=dsensgdurevdiff,
553+
dgdp_continuous=dsensgdprevdiff, sensealg=sensalg, abstol=abstol, reltol=reltol, kwargs...)
550554
end
551555
end
552556
if normalize
@@ -591,7 +595,8 @@ function getadjointsensitivities(syssim::Q, bsol::W3, target::String, solver::W;
591595
end
592596
dgdu(out, y, p, t) = ForwardDiff.gradient!(out, y -> g(y, p, t), y)
593597
dgdp(out, y, p, t) = ForwardDiff.gradient!(out, p -> g(y, p, t), p)
594-
du0, dpadj = adjoint_sensitivities(syssim.sol, solver, g, nothing, (dgdu, dgdp); sensealg=sensalg, abstol=abstol, reltol=reltol, kwargs...)
598+
du0, dpadj = adjoint_sensitivities(syssim.sol, solver; g=g, dgdu_continuous=dgdu, dgdp_continuous=dgdp,
599+
sensealg=sensalg, abstol=abstol, reltol=reltol, kwargs...)
595600
if normalize
596601
for domain in domains
597602
dpadj[domain.parameterindexes[1]+length(domain.phase.species):domain.parameterindexes[2]] .*= syssim.p[domain.parameterindexes[1]+length(domain.phase.species):domain.parameterindexes[2]]

0 commit comments

Comments
 (0)