1
1
using SciMLBase
2
- import SciMLBase: AbstractODESolution, HermiteInterpolation
3
- using DiffEqSensitivity
2
+ import SciMLBase: AbstractODESolution, HermiteInterpolation, AbstractDiffEqInterpolation
3
+ using SciMLSensitivity
4
4
using ForwardDiff
5
5
using PreallocationTools
6
+ using LinearAlgebra
6
7
7
8
abstract type AbstractSimulation end
8
9
export AbstractSimulation
@@ -459,8 +460,8 @@ By default uses the InterpolatingAdjoint algorithm with vector Jacobian products
459
460
this assumes no changes in code branching during simulation, if that were to become no longer true, the Tracker
460
461
based alternative algorithm is slower, but avoids this concern.
461
462
"""
462
- function getadjointsensitivities (bsol:: Q , target:: String , solver:: W ; sensalg:: W2 = InterpolatingAdjoint (autojacvec= ReverseDiffVJP (false )),
463
- abstol:: Float64 = 1e-6 , reltol:: Float64 = 1e-3 , normalize= true , kwargs... ) where {Q,W,W2}
463
+ function getadjointsensitivities (bsol:: Simulation , target:: String , solver; sensalg= InterpolatingAdjoint (autojacvec= ReverseDiffVJP (false )),
464
+ abstol:: Float64 = 1e-6 , reltol:: Float64 = 1e-3 , normalize= true , kwargs... )
464
465
@assert target in bsol. names || target in [" T" , " V" , " P" , " mass" ]
465
466
466
467
pethane = 160
@@ -538,15 +539,19 @@ function getadjointsensitivities(bsol::Q, target::String, solver::W; sensalg::W2
538
539
539
540
if length (bsol. domain. p) <= pethane
540
541
if target in [" T" , " V" , " P" , " mass" ] || ! isempty (bsol. interfaces)
541
- du0, dpadj = adjoint_sensitivities (bsol. sol, solver, g, nothing , (dgdu, dgdp); sensealg= sensalg, abstol= abstol, reltol= reltol, kwargs... )
542
+ du0, dpadj = adjoint_sensitivities (bsol. sol, solver; g= g, dgdu_continuous= dgdu,
543
+ dgdp_continuous= dgdp, sensealg= sensalg, abstol= abstol, reltol= reltol, kwargs... )
542
544
else
543
- du0, dpadj = adjoint_sensitivities (bsol. sol, solver, sensg, nothing , (dsensgdu, dsensgdp); sensealg= sensalg, abstol= abstol, reltol= reltol, kwargs... )
545
+ du0, dpadj = adjoint_sensitivities (bsol. sol, solver; g= sensg, dgdu_continuous= dsensgdu,
546
+ dgdp_continuous= dsensgdp, sensealg= sensalg, abstol= abstol, reltol= reltol, kwargs... )
544
547
end
545
548
else
546
549
if target in [" T" , " V" , " P" , " mass" ] || ! isempty (bsol. interfaces)
547
- du0, dpadj = adjoint_sensitivities (bsol. sol, solver, g, nothing , (dgdurevdiff, dgdprevdiff); sensealg= sensalg, abstol= abstol, reltol= reltol, kwargs... )
550
+ du0, dpadj = adjoint_sensitivities (bsol. sol, solver; g= g, dgdu_continuous= gdurevdiff,
551
+ dgdp_continuous= dgdprevdiff, sensealg= sensalg, abstol= abstol, reltol= reltol, kwargs... )
548
552
else
549
- du0, dpadj = adjoint_sensitivities (bsol. sol, solver, sensg, nothing , (dsensgdurevdiff, dsensgdprevdiff); sensealg= sensalg, abstol= abstol, reltol= reltol, kwargs... )
553
+ du0, dpadj = adjoint_sensitivities (bsol. sol, solver; g= sensg, dgdu_continuous= dsensgdurevdiff,
554
+ dgdp_continuous= dsensgdprevdiff, sensealg= sensalg, abstol= abstol, reltol= reltol, kwargs... )
550
555
end
551
556
end
552
557
if normalize
@@ -557,7 +562,7 @@ function getadjointsensitivities(bsol::Q, target::String, solver::W; sensalg::W2
557
562
dpadj ./= bsol. sol (bsol. sol. t[end ])[ind]
558
563
end
559
564
end
560
- return dpadj
565
+ return dpadj:: LinearAlgebra.Adjoint{Float64, Vector{Float64}}
561
566
end
562
567
563
568
function getadjointsensitivities (syssim:: Q , bsol:: W3 , target:: String , solver:: W ; sensalg:: W2 = InterpolatingAdjoint (autojacvec= ReverseDiffVJP (false )),
@@ -591,7 +596,8 @@ function getadjointsensitivities(syssim::Q, bsol::W3, target::String, solver::W;
591
596
end
592
597
dgdu (out, y, p, t) = ForwardDiff. gradient! (out, y -> g (y, p, t), y)
593
598
dgdp (out, y, p, t) = ForwardDiff. gradient! (out, p -> g (y, p, t), p)
594
- du0, dpadj = adjoint_sensitivities (syssim. sol, solver, g, nothing , (dgdu, dgdp); sensealg= sensalg, abstol= abstol, reltol= reltol, kwargs... )
599
+ du0, dpadj = adjoint_sensitivities (syssim. sol, solver; g= g, dgdu_continuous= dgdu, dgdp_continuous= dgdp,
600
+ sensealg= sensalg, abstol= abstol, reltol= reltol, kwargs... )
595
601
if normalize
596
602
for domain in domains
597
603
dpadj[domain. parameterindexes[1 ]+ length (domain. phase. species): domain. parameterindexes[2 ]] .*= syssim. p[domain. parameterindexes[1 ]+ length (domain. phase. species): domain. parameterindexes[2 ]]
@@ -600,7 +606,7 @@ function getadjointsensitivities(syssim::Q, bsol::W3, target::String, solver::W;
600
606
dpadj ./= bsol. sol (bsol. sol. t[end ])[ind]
601
607
end
602
608
end
603
- return dpadj
609
+ return dpadj:: LinearAlgebra.Adjoint{Float64, Vector{Float64}}
604
610
end
605
611
export getadjointsensitivities
606
612
0 commit comments