|
1 | 1 | using SciMLBase
|
2 |
| -import SciMLBase: AbstractODESolution, HermiteInterpolation |
3 |
| -using DiffEqSensitivity |
| 2 | +import SciMLBase: AbstractODESolution, HermiteInterpolation, AbstractDiffEqInterpolation |
| 3 | +using SciMLSensitivity |
4 | 4 | using ForwardDiff
|
5 | 5 | using PreallocationTools
|
6 | 6 |
|
@@ -538,15 +538,19 @@ function getadjointsensitivities(bsol::Q, target::String, solver::W; sensalg::W2
|
538 | 538 |
|
539 | 539 | if length(bsol.domain.p) <= pethane
|
540 | 540 | if target in ["T", "V", "P", "mass"] || !isempty(bsol.interfaces)
|
541 |
| - du0, dpadj = adjoint_sensitivities(bsol.sol, solver, g, nothing, (dgdu, dgdp); sensealg=sensalg, abstol=abstol, reltol=reltol, kwargs...) |
| 541 | + du0, dpadj = adjoint_sensitivities(bsol.sol, solver; g=g, dgdu_continuous=dgdu, |
| 542 | + dgdp_continuous=dgdp, sensealg=sensalg, abstol=abstol, reltol=reltol, kwargs...) |
542 | 543 | else
|
543 |
| - du0, dpadj = adjoint_sensitivities(bsol.sol, solver, sensg, nothing, (dsensgdu, dsensgdp); sensealg=sensalg, abstol=abstol, reltol=reltol, kwargs...) |
| 544 | + du0, dpadj = adjoint_sensitivities(bsol.sol, solver; g=sensg, dgdu_continuous=dsensgdu, |
| 545 | + dgdp_continuous=dsensgdp, sensealg=sensalg, abstol=abstol, reltol=reltol, kwargs...) |
544 | 546 | end
|
545 | 547 | else
|
546 | 548 | if target in ["T", "V", "P", "mass"] || !isempty(bsol.interfaces)
|
547 |
| - du0, dpadj = adjoint_sensitivities(bsol.sol, solver, g, nothing, (dgdurevdiff, dgdprevdiff); sensealg=sensalg, abstol=abstol, reltol=reltol, kwargs...) |
| 549 | + du0, dpadj = adjoint_sensitivities(bsol.sol, solver; g=g, dgdu_continuous=gdurevdiff, |
| 550 | + dgdp_continuous=dgdprevdiff, sensealg=sensalg, abstol=abstol, reltol=reltol, kwargs...) |
548 | 551 | else
|
549 |
| - du0, dpadj = adjoint_sensitivities(bsol.sol, solver, sensg, nothing, (dsensgdurevdiff, dsensgdprevdiff); sensealg=sensalg, abstol=abstol, reltol=reltol, kwargs...) |
| 552 | + du0, dpadj = adjoint_sensitivities(bsol.sol, solver; g=sensg, dgdu_continuous=dsensgdurevdiff, |
| 553 | + dgdp_continuous=dsensgdprevdiff, sensealg=sensalg, abstol=abstol, reltol=reltol, kwargs...) |
550 | 554 | end
|
551 | 555 | end
|
552 | 556 | if normalize
|
@@ -591,7 +595,8 @@ function getadjointsensitivities(syssim::Q, bsol::W3, target::String, solver::W;
|
591 | 595 | end
|
592 | 596 | dgdu(out, y, p, t) = ForwardDiff.gradient!(out, y -> g(y, p, t), y)
|
593 | 597 | dgdp(out, y, p, t) = ForwardDiff.gradient!(out, p -> g(y, p, t), p)
|
594 |
| - du0, dpadj = adjoint_sensitivities(syssim.sol, solver, g, nothing, (dgdu, dgdp); sensealg=sensalg, abstol=abstol, reltol=reltol, kwargs...) |
| 598 | + du0, dpadj = adjoint_sensitivities(syssim.sol, solver; g=g, dgdu_continuous=dgdu, dgdp_continuous=dgdp, |
| 599 | + sensealg=sensalg, abstol=abstol, reltol=reltol, kwargs...) |
595 | 600 | if normalize
|
596 | 601 | for domain in domains
|
597 | 602 | dpadj[domain.parameterindexes[1]+length(domain.phase.species):domain.parameterindexes[2]] .*= syssim.p[domain.parameterindexes[1]+length(domain.phase.species):domain.parameterindexes[2]]
|
|
0 commit comments