@@ -41,12 +41,24 @@ mutable struct RosenbrockCache{uType, rateType, uNoUnitsType, JType, WType, TabT
41
41
alg:: A
42
42
step_limiter!:: StepLimiter
43
43
stage_limiter!:: StageLimiter
44
+ order:: Int
44
45
end
45
46
function full_cache (c:: RosenbrockCache )
46
47
return [c. u, c. uprev, c. dense... , c. du, c. du1, c. du2,
47
48
c. ks... , c. fsalfirst, c. fsallast, c. dT, c. tmp, c. atmp, c. weight, c. linsolve_tmp]
48
49
end
49
50
51
+ struct RosenbrockCombinedConstantCache{TF, UF, Tab, JType, WType, F, AD} <: RosenbrockConstantCache
52
+ tf:: TF
53
+ uf:: UF
54
+ tab:: Tab
55
+ J:: JType
56
+ W:: WType
57
+ linsolve:: F
58
+ autodiff:: AD
59
+ order:: Int
60
+ end
61
+
50
62
@cache mutable struct Rosenbrock23Cache{uType, rateType, uNoUnitsType, JType, WType,
51
63
TabType, TFType, UFType, F, JCType, GCType,
52
64
RTolType, A, AV, StepLimiter, StageLimiter} <: RosenbrockMutableCache
702
714
703
715
# ## Rodas4 methods
704
716
705
- struct Rodas4ConstantCache{TF, UF, Tab, JType, WType, F, AD} <: RosenbrockConstantCache
706
- tf:: TF
707
- uf:: UF
708
- tab:: Tab
709
- J:: JType
710
- W:: WType
711
- linsolve:: F
712
- autodiff:: AD
713
- end
714
-
715
717
tabtype (:: Rodas4 ) = Rodas4Tableau
716
718
tabtype (:: Rodas42 ) = Rodas42Tableau
717
719
tabtype (:: Rodas4P ) = Rodas4PTableau
@@ -727,10 +729,10 @@ function alg_cache(alg::Union{Rodas4, Rodas42, Rodas4P, Rodas4P2},
727
729
J, W = build_J_W (alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val (false ))
728
730
linprob = nothing # LinearProblem(W,copy(u); u0=copy(u))
729
731
linsolve = nothing # init(linprob,alg.linsolve,alias_A=true,alias_b=true)
730
- Rodas4ConstantCache (tf, uf,
732
+ RosenbrockCombinedConstantCache (tf, uf,
731
733
tabtype (alg)(constvalue (uBottomEltypeNoUnits),
732
734
constvalue (tTypeNoUnits)), J, W, linsolve,
733
- alg_autodiff (alg))
735
+ alg_autodiff (alg), 4 )
734
736
end
735
737
736
738
function alg_cache (alg:: Union{Rodas4, Rodas42, Rodas4P, Rodas4P2} ,
@@ -783,81 +785,22 @@ function alg_cache(alg::Union{Rodas4, Rodas42, Rodas4P, Rodas4P2},
783
785
u, uprev, dense, du, du1, du2, ks, fsalfirst, fsallast,
784
786
dT, J, W, tmp, atmp, weight, tab, tf, uf, linsolve_tmp,
785
787
linsolve, jac_config, grad_config, reltol, alg,
786
- alg. step_limiter!, alg. stage_limiter!)
788
+ alg. step_limiter!, alg. stage_limiter!, 4 )
787
789
end
788
790
789
791
# ###############################################################################
790
792
791
793
# ## Rosenbrock5
792
794
793
- struct Rosenbrock5ConstantCache{TF, UF, Tab, JType, WType, F} <: RosenbrockConstantCache
794
- tf:: TF
795
- uf:: UF
796
- tab:: Tab
797
- J:: JType
798
- W:: WType
799
- linsolve:: F
800
- end
801
-
802
- @cache mutable struct Rosenbrock5Cache{
803
- uType, rateType, uNoUnitsType, JType, WType, TabType,
804
- TFType, UFType, F, JCType, GCType, RTolType, A, StepLimiter, StageLimiter} < :
805
- RosenbrockMutableCache
806
- u:: uType
807
- uprev:: uType
808
- dense1:: rateType
809
- dense2:: rateType
810
- dense3:: rateType
811
- du:: rateType
812
- du1:: rateType
813
- du2:: rateType
814
- k1:: rateType
815
- k2:: rateType
816
- k3:: rateType
817
- k4:: rateType
818
- k5:: rateType
819
- k6:: rateType
820
- k7:: rateType
821
- k8:: rateType
822
- fsalfirst:: rateType
823
- fsallast:: rateType
824
- dT:: rateType
825
- J:: JType
826
- W:: WType
827
- tmp:: rateType
828
- atmp:: uNoUnitsType
829
- weight:: uNoUnitsType
830
- tab:: TabType
831
- tf:: TFType
832
- uf:: UFType
833
- linsolve_tmp:: rateType
834
- linsolve:: F
835
- jac_config:: JCType
836
- grad_config:: GCType
837
- reltol:: RTolType
838
- alg:: A
839
- step_limiter!:: StepLimiter
840
- stage_limiter!:: StageLimiter
841
- end
842
-
843
795
function alg_cache (alg:: Rodas5 , u, rate_prototype, :: Type{uEltypeNoUnits} ,
844
796
:: Type{uBottomEltypeNoUnits} , :: Type{tTypeNoUnits} , uprev, uprev2, f, t,
845
797
dt, reltol, p, calck,
846
798
:: Val{true} ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
847
- dense1 = zero (rate_prototype)
848
- dense2 = zero (rate_prototype)
849
- dense3 = zero (rate_prototype)
799
+ dense = [zero (rate_prototype) for _ in 1 : 3 ]
850
800
du = zero (rate_prototype)
851
801
du1 = zero (rate_prototype)
852
802
du2 = zero (rate_prototype)
853
- k1 = zero (rate_prototype)
854
- k2 = zero (rate_prototype)
855
- k3 = zero (rate_prototype)
856
- k4 = zero (rate_prototype)
857
- k5 = zero (rate_prototype)
858
- k6 = zero (rate_prototype)
859
- k7 = zero (rate_prototype)
860
- k8 = zero (rate_prototype)
803
+ ks = [zero (rate_prototype) for _ in 1 : 7 ]
861
804
fsalfirst = zero (rate_prototype)
862
805
fsallast = zero (rate_prototype)
863
806
dT = zero (rate_prototype)
@@ -881,12 +824,11 @@ function alg_cache(alg::Rodas5, u, rate_prototype, ::Type{uEltypeNoUnits},
881
824
assumptions = LinearSolve. OperatorAssumptions (true ))
882
825
grad_config = build_grad_config (alg, f, tf, du1, t)
883
826
jac_config = build_jac_config (alg, f, uf, du1, uprev, u, tmp, du2)
884
- Rosenbrock5Cache (u, uprev, dense1, dense2, dense3, du, du1, du2, k1, k2, k3, k4,
885
- k5, k6, k7, k8,
827
+ RosenbrockCache (u, uprev, dense, du, du1, du2, ks,
886
828
fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf,
887
829
linsolve_tmp,
888
830
linsolve, jac_config, grad_config, reltol, alg, alg. step_limiter!,
889
- alg. stage_limiter!)
831
+ alg. stage_limiter!, 5 )
890
832
end
891
833
892
834
function alg_cache (alg:: Rodas5 , u, rate_prototype, :: Type{uEltypeNoUnits} ,
@@ -898,30 +840,21 @@ function alg_cache(alg::Rodas5, u, rate_prototype, ::Type{uEltypeNoUnits},
898
840
J, W = build_J_W (alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val (false ))
899
841
linprob = nothing # LinearProblem(W,copy(u); u0=copy(u))
900
842
linsolve = nothing # init(linprob,alg.linsolve,alias_A=true,alias_b=true)
901
- Rosenbrock5ConstantCache (tf, uf,
843
+ RosenbrockCombinedConstantCache (tf, uf,
902
844
Rodas5Tableau (constvalue (uBottomEltypeNoUnits),
903
- constvalue (tTypeNoUnits)), J, W, linsolve)
845
+ constvalue (tTypeNoUnits)), J, W, linsolve, alg_autodiff (alg), 5 )
904
846
end
905
847
906
848
function alg_cache (
907
849
alg:: Union{Rodas5P, Rodas5Pe, Rodas5Pr} , u, rate_prototype, :: Type{uEltypeNoUnits} ,
908
850
:: Type{uBottomEltypeNoUnits} , :: Type{tTypeNoUnits} , uprev, uprev2, f, t,
909
851
dt, reltol, p, calck,
910
852
:: Val{true} ) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
911
- dense1 = zero (rate_prototype)
912
- dense2 = zero (rate_prototype)
913
- dense3 = zero (rate_prototype)
853
+ dense = [zero (rate_prototype) for _ in 1 : 3 ]
914
854
du = zero (rate_prototype)
915
855
du1 = zero (rate_prototype)
916
856
du2 = zero (rate_prototype)
917
- k1 = zero (rate_prototype)
918
- k2 = zero (rate_prototype)
919
- k3 = zero (rate_prototype)
920
- k4 = zero (rate_prototype)
921
- k5 = zero (rate_prototype)
922
- k6 = zero (rate_prototype)
923
- k7 = zero (rate_prototype)
924
- k8 = zero (rate_prototype)
857
+ ks = [zero (rate_prototype) for _ in 1 : 8 ]
925
858
fsalfirst = zero (rate_prototype)
926
859
fsallast = zero (rate_prototype)
927
860
dT = zero (rate_prototype)
@@ -945,12 +878,11 @@ function alg_cache(
945
878
assumptions = LinearSolve. OperatorAssumptions (true ))
946
879
grad_config = build_grad_config (alg, f, tf, du1, t)
947
880
jac_config = build_jac_config (alg, f, uf, du1, uprev, u, tmp, du2)
948
- Rosenbrock5Cache (u, uprev, dense1, dense2, dense3, du, du1, du2, k1, k2, k3, k4,
949
- k5, k6, k7, k8,
881
+ RosenbrockCache (u, uprev, dense, du, du1, du2, ks,
950
882
fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf,
951
883
linsolve_tmp,
952
884
linsolve, jac_config, grad_config, reltol, alg, alg. step_limiter!,
953
- alg. stage_limiter!)
885
+ alg. stage_limiter!, 5 )
954
886
end
955
887
956
888
function alg_cache (
@@ -963,9 +895,9 @@ function alg_cache(
963
895
J, W = build_J_W (alg, u, uprev, p, t, dt, f, uEltypeNoUnits, Val (false ))
964
896
linprob = nothing # LinearProblem(W,copy(u); u0=copy(u))
965
897
linsolve = nothing # init(linprob,alg.linsolve,alias_A=true,alias_b=true)
966
- Rosenbrock5ConstantCache (tf, uf,
898
+ RosenbrockCombinedConstantCache (tf, uf,
967
899
Rodas5PTableau (constvalue (uBottomEltypeNoUnits),
968
- constvalue (tTypeNoUnits)), J, W, linsolve)
900
+ constvalue (tTypeNoUnits)), J, W, linsolve, alg_autodiff (alg), 5 )
969
901
end
970
902
971
903
function get_fsalfirstlast (
0 commit comments