Skip to content

Commit 40a1d04

Browse files
author
Hadrien
committed
Add ABOBA, start implement OBABO, check to be done on weak order
1 parent 66f4b23 commit 40a1d04

File tree

6 files changed

+338
-25
lines changed

6 files changed

+338
-25
lines changed

src/StochasticDiffEq.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ end
169169

170170
export TauLeaping, CaoTauLeaping
171171

172-
export BAOAB
172+
export BAOAB, ABOBA, OBABO
173173

174174
export StochasticDiffEqRODEAlgorithm, StochasticDiffEqRODEAdaptiveAlgorithm,
175175
StochasticDiffEqRODECompositeAlgorithm

src/alg_utils.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,8 @@ alg_order(alg::TauLeaping) = 1 // 1
122122
alg_order(alg::CaoTauLeaping) = 1 // 1
123123

124124
alg_order(alg::BAOAB) = 1 // 1
125+
alg_order(alg::ABOBA) = 2 // 1
126+
alg_order(alg::OBABO) = 2 // 1
125127

126128
alg_order(alg::SKenCarp) = 2 // 1
127129
alg_order(alg::Union{StochasticDiffEqCompositeAlgorithm,StochasticDiffEqRODECompositeAlgorithm}) = maximum(alg_order.(alg.algs))
@@ -225,6 +227,8 @@ alg_compatible(prob::DiffEqBase.AbstractSDEProblem, alg::IIF1M) = true
225227
alg_compatible(prob::DiffEqBase.AbstractSDEProblem, alg::IIF2M) = true
226228
alg_compatible(prob::DiffEqBase.AbstractSDEProblem, alg::Union{StochasticDiffEqCompositeAlgorithm,StochasticDiffEqRODECompositeAlgorithm}) = max((alg_compatible(prob, a) for a in alg.algs)...)
227229
alg_compatible(prob::DiffEqBase.AbstractSDEProblem, alg::BAOAB) = is_diagonal_noise(prob)
230+
alg_compatible(prob::DiffEqBase.AbstractSDEProblem, alg::ABOBA) = is_diagonal_noise(prob)
231+
alg_compatible(prob::DiffEqBase.AbstractSDEProblem, alg::OBABO) = is_diagonal_noise(prob)
228232

229233
function alg_compatible(prob::JumpProblem, alg::Union{StochasticDiffEqJumpAdaptiveAlgorithm,StochasticDiffEqJumpAlgorithm})
230234
prob.prob isa DiscreteProblem

src/algorithms.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -869,3 +869,16 @@ struct BAOAB{T} <: StochasticDiffEqAlgorithm
869869
scale_noise::Bool
870870
end
871871
BAOAB(;gamma=1.0, scale_noise=true) = BAOAB(gamma, scale_noise)
872+
873+
struct ABOBA{T} <: StochasticDiffEqAlgorithm
874+
gamma::T
875+
scale_noise::Bool
876+
end
877+
ABOBA(;gamma=1.0, scale_noise=true) = ABOBA(gamma, scale_noise)
878+
879+
880+
struct OBABO{T} <: StochasticDiffEqAlgorithm
881+
gamma::T
882+
scale_noise::Bool
883+
end
884+
OBABO(;gamma=1.0, scale_noise=true) = OBABO(gamma, scale_noise)

src/caches/dynamical_caches.jl

Lines changed: 164 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,191 @@
1+
abstract type StochasticDynamicalEqConstantCache <: StochasticDiffEqConstantCache end # Pourquoi faire ça, Si c'est pour avoir une seul function de check dans initialize!
2+
abstract type StochasticDynamicalEqMutableCache <: StochasticDiffEqMutableCache end
13

2-
mutable struct BAOABConstantCache{uType,uEltypeNoUnits} <: StochasticDiffEqConstantCache
4+
5+
mutable struct BAOABConstantCache{uType,uEltypeNoUnits,uCoeffType} <: StochasticDynamicalEqConstantCache
36
k::uType
47
half::uEltypeNoUnits
5-
c1::uEltypeNoUnits
6-
c2::uEltypeNoUnits
8+
c1::uCoeffType
9+
c2::uCoeffType
710
end
8-
@cache struct BAOABCache{uType,uEltypeNoUnits,rateNoiseType,uTypeCombined} <: StochasticDiffEqMutableCache
11+
@cache struct BAOABCache{uType,uEltypeNoUnits,rateNoiseType,uCoeffType,uTypeCombined} <: StochasticDynamicalEqMutableCache
912
utmp::uType
1013
dutmp::uType
1114
k::uType
12-
gtmp::uType
13-
noise::rateNoiseType
15+
gtmp::rateNoiseType
16+
noise::uType
1417
half::uEltypeNoUnits
15-
c1::uEltypeNoUnits
16-
c2::uEltypeNoUnits
18+
c1::uCoeffType
19+
c2::uCoeffType
1720
tmp::uTypeCombined
1821
end
1922

2023
function alg_cache(alg::BAOAB,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototype,jump_rate_prototype,::Type{uEltypeNoUnits},::Type{uBottomEltypeNoUnits},::Type{tTypeNoUnits},uprev,f,t,dt,::Type{Val{false}}) where {uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits}
2124
k = zero(rate_prototype.x[1])
22-
c1 = exp(-alg.gamma*dt)
23-
c2 = sqrt(1 - alg.scale_noise*c1^2) # if scale_noise == false, c2 = 1
24-
BAOABConstantCache(k, uEltypeNoUnits(1//2), uEltypeNoUnits(c1), uEltypeNoUnits(c2))
25+
if typeof(alg.gamma) <: Number
26+
c1 = exp.(-alg.gamma*dt)
27+
c2 = sqrt.(1 .- alg.scale_noise*c1.^2)# if scale_noise == false, c2 = 1
28+
elseif typeof(alg.gamma) <: AbstractMatrix
29+
c1 = exp(-alg.gamma*dt)
30+
c2 = cholesky(I - alg.scale_noise*c1*transpose(c1)).U# if scale_noise == false, c2 = 1
31+
else
32+
c1 = exp.(-alg.gamma*dt)
33+
c2 = sqrt.(1 .- alg.scale_noise*c1.^2)# if scale_noise == false, c2 = 1
34+
end
35+
BAOABConstantCache(k, uEltypeNoUnits(1//2),c1, c2)
2536
end
2637

2738
function alg_cache(alg::BAOAB,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototype,jump_rate_prototype,::Type{uEltypeNoUnits},::Type{uBottomEltypeNoUnits},::Type{tTypeNoUnits},uprev,f,t,dt,::Type{Val{true}}) where {uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits}
2839
dutmp = zero(u.x[1])
2940
utmp = zero(u.x[2])
3041
k = zero(rate_prototype.x[1])
3142

32-
gtmp = zero(rate_prototype.x[1])
43+
gtmp = zero(noise_rate_prototype)
44+
noise = zero(rate_prototype.x[1])
45+
46+
half = uEltypeNoUnits(1//2)
47+
if typeof(alg.gamma) <: Number
48+
c1 = exp.(-alg.gamma*dt)
49+
c2 = sqrt.(1 .- alg.scale_noise*c1.^2)# if scale_noise == false, c2 = 1
50+
elseif typeof(alg.gamma) <: AbstractMatrix
51+
c1 = exp(-alg.gamma*dt)
52+
c2 = cholesky(I - alg.scale_noise*c1*transpose(c1)).U# if scale_noise == false, c2 = 1
53+
else
54+
c1 = exp.(-alg.gamma*dt)
55+
c2 = sqrt.(1 .- alg.scale_noise*c1.^2)# if scale_noise == false, c2 = 1
56+
end
57+
58+
tmp = zero(u)
59+
60+
BAOABCache(utmp, dutmp, k, gtmp, noise, half, c1, c2, tmp)
61+
end
62+
63+
64+
65+
mutable struct ABOBAConstantCache{uType,uEltypeNoUnits, uCoeffType} <: StochasticDynamicalEqConstantCache
66+
k::uType
67+
half::uEltypeNoUnits
68+
c₂::uCoeffType
69+
σ::uCoeffType
70+
end
71+
@cache struct ABOBACache{uType,uEltypeNoUnits,rateNoiseType,uCoeffType,uTypeCombined} <: StochasticDynamicalEqMutableCache
72+
utmp::uType
73+
dutmp::uType
74+
k::uType
75+
gtmp::rateNoiseType
76+
noise::uType
77+
half::uEltypeNoUnits
78+
c₂::uCoeffType
79+
σ::uCoeffType
80+
tmp::uTypeCombined
81+
end
82+
83+
function alg_cache(alg::ABOBA,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototype,jump_rate_prototype,::Type{uEltypeNoUnits},::Type{uBottomEltypeNoUnits},::Type{tTypeNoUnits},uprev,f,t,dt,::Type{Val{false}}) where {uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits}
84+
k = zero(rate_prototype.x[1])
85+
86+
if typeof(alg.gamma) <: Number
87+
c₂ = exp.(-alg.gamma*dt)
88+
σ = sqrt.(1 .- alg.scale_noise*c₂.^2)
89+
elseif typeof(alg.gamma) <: AbstractMatrix
90+
c₂ = exp(-alg.gamma*dt)
91+
σ = cholesky(I - alg.scale_noise*c₂*transpose(c₂)).U
92+
else
93+
c₂ = exp.(-alg.gamma*dt)
94+
σ = sqrt.(1 .- alg.scale_noise*c₂.^2)
95+
end
96+
# if scale_noise == false, c2 = 1
97+
ABOBAConstantCache(k, uEltypeNoUnits(1//2), c₂, σ)
98+
end
99+
100+
function alg_cache(alg::ABOBA,prob,u,ΔW,ΔZ,p,rate_prototype, noise_rate_prototype,jump_rate_prototype,::Type{uEltypeNoUnits},::Type{uBottomEltypeNoUnits},::Type{tTypeNoUnits},uprev,f,t,dt,::Type{Val{true}}) where {uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits}
101+
dutmp = zero(u.x[1])
102+
utmp = zero(u.x[2])
103+
k = zero(rate_prototype.x[1])
104+
105+
gtmp = zero(noise_rate_prototype)
106+
noise = zero(rate_prototype.x[1])
107+
108+
half = uEltypeNoUnits(1//2)
109+
110+
if typeof(alg.gamma) <: Number
111+
c₂ = exp.(-alg.gamma*dt)
112+
σ = sqrt.(1 .- alg.scale_noise*c₂.^2)
113+
elseif typeof(alg.gamma) <: AbstractMatrix
114+
c₂ = exp(-alg.gamma*dt)
115+
σ = cholesky(I - alg.scale_noise*c₂*transpose(c₂)).U
116+
else
117+
c₂ = exp.(-alg.gamma*dt)
118+
σ = sqrt.(1 .- alg.scale_noise*c₂.^2)
119+
end
120+
121+
tmp = zero(u)
122+
123+
ABOBACache(utmp, dutmp, k, gtmp, noise, half, c₂, σ, tmp)
124+
end
125+
126+
127+
128+
129+
mutable struct OBABOConstantCache{uType,uEltypeNoUnits, uCoeffType} <: StochasticDynamicalEqConstantCache
130+
k::uType
131+
half::uEltypeNoUnits
132+
c₂::uCoeffType
133+
σ::uCoeffType
134+
end
135+
@cache struct OBABOCache{uType,uEltypeNoUnits,rateNoiseType,uCoeffType,uTypeCombined} <: StochasticDynamicalEqMutableCache
136+
utmp::uType
137+
dutmp::uType
138+
k::uType
139+
gtmp::rateNoiseType
140+
noise::uType
141+
noisetmp::uType
142+
half::uEltypeNoUnits
143+
c₂::uCoeffType
144+
σ::uCoeffType
145+
tmp::uTypeCombined
146+
end
147+
148+
function alg_cache(alg::OBABO,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototype,jump_rate_prototype,::Type{uEltypeNoUnits},::Type{uBottomEltypeNoUnits},::Type{tTypeNoUnits},uprev,f,t,dt,::Type{Val{false}}) where {uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits}
149+
k = zero(rate_prototype.x[1])
150+
half=uEltypeNoUnits(1//2)
151+
if typeof(alg.gamma) <: Number
152+
c₂ = exp.(-alg.gamma*half*dt)
153+
σ = sqrt.(1 .- alg.scale_noise*c₂.^2)
154+
elseif typeof(alg.gamma) <: AbstractMatrix
155+
c₂ = exp(-alg.gamma*half*dt)
156+
σ = cholesky(I - alg.scale_noise*c₂*transpose(c₂)).U
157+
else
158+
c₂ = exp.(-alg.gamma*half*dt)
159+
σ = sqrt.(1 .- alg.scale_noise*c₂.^2)
160+
end
161+
# if scale_noise == false, c2 = 1
162+
OBABOConstantCache(k, half, c₂, σ)
163+
end
164+
165+
function alg_cache(alg::OBABO,prob,u,ΔW,ΔZ,p,rate_prototype, noise_rate_prototype,jump_rate_prototype,::Type{uEltypeNoUnits},::Type{uBottomEltypeNoUnits},::Type{tTypeNoUnits},uprev,f,t,dt,::Type{Val{true}}) where {uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits}
166+
dutmp = zero(u.x[1])
167+
utmp = zero(u.x[2])
168+
k = zero(rate_prototype.x[1])
33169
noise = zero(rate_prototype.x[1])
34170

171+
gtmp = zero(noise_rate_prototype)
172+
noisetmp = zero(rate_prototype.x[1])
173+
174+
35175
half = uEltypeNoUnits(1//2)
36-
c1 = exp(-alg.gamma*dt)
37-
c2 = sqrt(1 - alg.scale_noise*c1^2) # if scale_noise == false, c2 = 1
176+
177+
if typeof(alg.gamma) <: Number
178+
c₂ = exp.(-alg.gamma*half*dt)
179+
σ = sqrt.(1 .- alg.scale_noise*c₂.^2)
180+
elseif typeof(alg.gamma) <: AbstractMatrix
181+
c₂ = exp(-alg.gamma*half*dt)
182+
σ = cholesky(I - alg.scale_noise*c₂*transpose(c₂)).U
183+
else
184+
c₂ = exp.(-alg.gamma*half*dt)
185+
σ = sqrt.(1 .- alg.scale_noise*c₂.^2)
186+
end
38187

39188
tmp = zero(u)
40189

41-
BAOABCache(utmp, dutmp, k, gtmp, noise, half, uEltypeNoUnits(c1), uEltypeNoUnits(c2), tmp)
190+
OBABOCache(utmp, dutmp, k, noise, gtmp, noisetmp, half, c₂, σ, tmp)
42191
end

src/perform_step/dynamical.jl

Lines changed: 101 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
function verify_f2(f, p, q, pa, t, integrator, ::BAOABConstantCache)
1+
function verify_f2(f, p, q, pa, t, integrator, ::StochasticDynamicalEqConstantCache)
22
res = f(p, q, pa, t)
33
res != p && throwex(integrator)
44
end
5-
function verify_f2(f, res, p, q, pa, t, integrator, ::BAOABCache)
5+
function verify_f2(f, res, p, q, pa, t, integrator, ::StochasticDiffEqMutableCache)
66
f(res, p, q, pa, t)
77
res != p && throwex(integrator)
88
end
@@ -11,7 +11,7 @@ function throwex(integrator)
1111
throw(ArgumentError("Algorithm $algn is not applicable if f2(p, q, t) != p"))
1212
end
1313

14-
function initialize!(integrator, cache::BAOABConstantCache)
14+
function initialize!(integrator, cache::Union{BAOABConstantCache,ABOBAConstantCache})
1515
@unpack t,dt,uprev,u,p,W = integrator
1616
du1 = integrator.uprev.x[1]
1717
u1 = integrator.uprev.x[2]
@@ -20,7 +20,7 @@ function initialize!(integrator, cache::BAOABConstantCache)
2020
cache.k = integrator.f.f1(du1,u1,p,t)
2121
end
2222

23-
function initialize!(integrator, cache::BAOABCache)
23+
function initialize!(integrator, cache::Union{BAOABCache,ABOBACache})
2424
@unpack t,dt,uprev,u,p,W = integrator
2525
du1 = integrator.uprev.x[1]
2626
u1 = integrator.uprev.x[2]
@@ -42,8 +42,13 @@ end
4242
u2 = u1 + half*dt*du2
4343

4444
# O
45-
noise = integrator.g(u2,p,t+dt*half).*W.dW / sqdt
46-
du3 = c1*du2 + c2*noise
45+
noise = integrator.g(u2,p,t+dt*half).*W.dW ./ sqdt
46+
if typeof(c2) <: AbstractMatrix || typeof(noise) <: Number
47+
du3 = c1*du2 + c2*noise
48+
else
49+
du3 = c1.*du2 + c2.*noise
50+
end
51+
4752

4853
# A
4954
u = u2 + half*dt*du3
@@ -70,7 +75,13 @@ end
7075
# O
7176
integrator.g(gtmp,utmp,p,t+dt*half)
7277
@.. noise = gtmp*W.dW / sqdt
73-
@.. dutmp = c1*dutmp + c2*noise
78+
if typeof(c2) <: AbstractMatrix
79+
mul!(dutmp,c1,dutmp)
80+
mul!(noise,c2,noise)
81+
@.. dutmp+= noise
82+
else
83+
@.. dutmp = c1*dutmp + c2*noise
84+
end
7485

7586
# A
7687
@.. u.x[2] = utmp + half*dt*dutmp
@@ -79,3 +90,86 @@ end
7990
f.f1(k,dutmp,u.x[2],p,t+dt)
8091
@.. u.x[1] = dutmp + half*dt*k
8192
end
93+
94+
95+
@muladd function perform_step!(integrator,cache::ABOBAConstantCache)
96+
@unpack t,dt,sqdt,uprev,u,p,W,f = integrator
97+
@unpack half, c₂, σ = cache
98+
du1 = uprev.x[1]
99+
u1 = uprev.x[2]
100+
101+
# A
102+
u_mid = u1 + half*dt*du1
103+
104+
# BOB: du_t+1
105+
cache.k = f.f1(du1,u_mid,p,t+half*dt)
106+
noise = integrator.g(u_mid,p,t+dt*half).*W.dW / sqdt
107+
108+
if typeof(σ) <: AbstractMatrix || typeof(noise) <: Number
109+
du = c₂ * (du1 + half*dt .* cache.k) .+ σ*noise .+ half * dt .*cache.k
110+
else
111+
du = c₂ .* (du1 + half*dt .* cache.k) .+ σ.*noise .+ half * dt .*cache.k
112+
end
113+
# A: xt+1
114+
u = u_mid .+ half * dt .*du
115+
116+
integrator.u = ArrayPartition((du, u))
117+
end
118+
119+
120+
@muladd function perform_step!(integrator,cache::ABOBACache)
121+
@unpack t,dt,sqdt,uprev,u,p,W,f = integrator
122+
@unpack utmp, dutmp, k, gtmp, noise, half, c₂, σ = cache
123+
du1 = uprev.x[1]
124+
u1 = uprev.x[2]
125+
126+
# A: xt+1/2
127+
@.. utmp = u1 + half*dt*du1
128+
129+
130+
# B
131+
f.f1(k,du1,utmp,p,t+dt)
132+
@.. dutmp = du1 + half*dt*k
133+
134+
# O
135+
integrator.g(gtmp,utmp,p,t+dt*half)
136+
@.. noise = gtmp*W.dW / sqdt
137+
138+
if typeof(σ) <: AbstractMatrix
139+
mul!(dutmp,c₂,dutmp)
140+
mul!(noise,σ,noise)
141+
@.. dutmp+=noise
142+
else
143+
@.. dutmp = c₂*dutmp + σ*noise
144+
end
145+
146+
147+
# B
148+
@.. u.x[1] = dutmp + half*dt*k
149+
150+
# A: xt+1
151+
@.. u.x[2] = utmp + half*dt*dutmp
152+
end
153+
154+
155+
156+
157+
function initialize!(integrator, cache::OBABOConstantCache)
158+
@unpack t,dt,uprev,u,p,W = integrator
159+
du1 = integrator.uprev.x[1]
160+
u1 = integrator.uprev.x[2]
161+
162+
verify_f2(integrator.f.f2, du1, u1, p, t, integrator, cache)
163+
cache.k = integrator.f.f1(du1,u1,p,t)
164+
end
165+
166+
function initialize!(integrator, cache::OBABOCache)
167+
@unpack t,dt,uprev,u,p,W = integrator
168+
du1 = integrator.uprev.x[1]
169+
u1 = integrator.uprev.x[2]
170+
171+
verify_f2(integrator.f.f2, cache.k, du1, u1, p, t, integrator, cache)
172+
integrator.f.f1(cache.k,du1,u1,p,t)
173+
174+
integrator.g(cache.gtmp,u1,p,t)
175+
end

0 commit comments

Comments
 (0)