Skip to content

Commit 4e5d31e

Browse files
author
Hadrien
committed
Implement OBABO, remains some error on weak order
1 parent 40a1d04 commit 4e5d31e

File tree

3 files changed

+165
-72
lines changed

3 files changed

+165
-72
lines changed

src/caches/dynamical_caches.jl

Lines changed: 40 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -2,30 +2,29 @@ abstract type StochasticDynamicalEqConstantCache <: StochasticDiffEqConstantCach
22
abstract type StochasticDynamicalEqMutableCache <: StochasticDiffEqMutableCache end
33

44

5-
mutable struct BAOABConstantCache{uType,uEltypeNoUnits,uCoeffType} <: StochasticDynamicalEqConstantCache
5+
mutable struct BAOABConstantCache{uType,uEltypeNoUnits,uCoeffType, uCoeffMType} <: StochasticDynamicalEqConstantCache
66
k::uType
77
half::uEltypeNoUnits
88
c1::uCoeffType
9-
c2::uCoeffType
9+
c2::uCoeffMType
1010
end
11-
@cache struct BAOABCache{uType,uEltypeNoUnits,rateNoiseType,uCoeffType,uTypeCombined} <: StochasticDynamicalEqMutableCache
11+
@cache struct BAOABCache{uType,uEltypeNoUnits,rateNoiseType,uCoeffType, uCoeffMType,uTypeCombined} <: StochasticDynamicalEqMutableCache
1212
utmp::uType
13+
dumid::uType
1314
dutmp::uType
15+
dunoise::uType
1416
k::uType
1517
gtmp::rateNoiseType
1618
noise::uType
1719
half::uEltypeNoUnits
1820
c1::uCoeffType
19-
c2::uCoeffType
21+
c2::uCoeffMType
2022
tmp::uTypeCombined
2123
end
2224

2325
function alg_cache(alg::BAOAB,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototype,jump_rate_prototype,::Type{uEltypeNoUnits},::Type{uBottomEltypeNoUnits},::Type{tTypeNoUnits},uprev,f,t,dt,::Type{Val{false}}) where {uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits}
2426
k = zero(rate_prototype.x[1])
25-
if typeof(alg.gamma) <: Number
26-
c1 = exp.(-alg.gamma*dt)
27-
c2 = sqrt.(1 .- alg.scale_noise*c1.^2)# if scale_noise == false, c2 = 1
28-
elseif typeof(alg.gamma) <: AbstractMatrix
27+
if typeof(alg.gamma) <: AbstractMatrix
2928
c1 = exp(-alg.gamma*dt)
3029
c2 = cholesky(I - alg.scale_noise*c1*transpose(c1)).U# if scale_noise == false, c2 = 1
3130
else
@@ -36,18 +35,18 @@ function alg_cache(alg::BAOAB,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototy
3635
end
3736

3837
function alg_cache(alg::BAOAB,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototype,jump_rate_prototype,::Type{uEltypeNoUnits},::Type{uBottomEltypeNoUnits},::Type{tTypeNoUnits},uprev,f,t,dt,::Type{Val{true}}) where {uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits}
38+
dumid = zero(u.x[1])
3939
dutmp = zero(u.x[1])
40+
dunoise = zero(u.x[1])
4041
utmp = zero(u.x[2])
4142
k = zero(rate_prototype.x[1])
4243

4344
gtmp = zero(noise_rate_prototype)
4445
noise = zero(rate_prototype.x[1])
4546

4647
half = uEltypeNoUnits(1//2)
47-
if typeof(alg.gamma) <: Number
48-
c1 = exp.(-alg.gamma*dt)
49-
c2 = sqrt.(1 .- alg.scale_noise*c1.^2)# if scale_noise == false, c2 = 1
50-
elseif typeof(alg.gamma) <: AbstractMatrix
48+
49+
if typeof(alg.gamma) <: AbstractMatrix
5150
c1 = exp(-alg.gamma*dt)
5251
c2 = cholesky(I - alg.scale_noise*c1*transpose(c1)).U# if scale_noise == false, c2 = 1
5352
else
@@ -57,36 +56,35 @@ function alg_cache(alg::BAOAB,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototy
5756

5857
tmp = zero(u)
5958

60-
BAOABCache(utmp, dutmp, k, gtmp, noise, half, c1, c2, tmp)
59+
BAOABCache(utmp, dumid, dutmp, dunoise, k, gtmp, noise, half, c1, c2, tmp)
6160
end
6261

6362

6463

65-
mutable struct ABOBAConstantCache{uType,uEltypeNoUnits, uCoeffType} <: StochasticDynamicalEqConstantCache
64+
mutable struct ABOBAConstantCache{uType,uEltypeNoUnits, uCoeffType, uCoeffMType} <: StochasticDynamicalEqConstantCache
6665
k::uType
6766
half::uEltypeNoUnits
6867
c₂::uCoeffType
69-
σ::uCoeffType
68+
σ::uCoeffMType
7069
end
71-
@cache struct ABOBACache{uType,uEltypeNoUnits,rateNoiseType,uCoeffType,uTypeCombined} <: StochasticDynamicalEqMutableCache
70+
@cache struct ABOBACache{uType,uEltypeNoUnits,rateNoiseType,uCoeffType, uCoeffMType,uTypeCombined} <: StochasticDynamicalEqMutableCache
7271
utmp::uType
72+
dumid::uType
7373
dutmp::uType
74+
dunoise::uType
7475
k::uType
7576
gtmp::rateNoiseType
7677
noise::uType
7778
half::uEltypeNoUnits
7879
c₂::uCoeffType
79-
σ::uCoeffType
80+
σ::uCoeffMType
8081
tmp::uTypeCombined
8182
end
8283

8384
function alg_cache(alg::ABOBA,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototype,jump_rate_prototype,::Type{uEltypeNoUnits},::Type{uBottomEltypeNoUnits},::Type{tTypeNoUnits},uprev,f,t,dt,::Type{Val{false}}) where {uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits}
8485
k = zero(rate_prototype.x[1])
8586

86-
if typeof(alg.gamma) <: Number
87-
c₂ = exp.(-alg.gamma*dt)
88-
σ = sqrt.(1 .- alg.scale_noise*c₂.^2)
89-
elseif typeof(alg.gamma) <: AbstractMatrix
87+
if typeof(alg.gamma) <: AbstractMatrix
9088
c₂ = exp(-alg.gamma*dt)
9189
σ = cholesky(I - alg.scale_noise*c₂*transpose(c₂)).U
9290
else
@@ -99,6 +97,8 @@ end
9997

10098
function alg_cache(alg::ABOBA,prob,u,ΔW,ΔZ,p,rate_prototype, noise_rate_prototype,jump_rate_prototype,::Type{uEltypeNoUnits},::Type{uBottomEltypeNoUnits},::Type{tTypeNoUnits},uprev,f,t,dt,::Type{Val{true}}) where {uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits}
10199
dutmp = zero(u.x[1])
100+
dumid = zero(u.x[1])
101+
dunoise = zero(u.x[1])
102102
utmp = zero(u.x[2])
103103
k = zero(rate_prototype.x[1])
104104

@@ -107,10 +107,7 @@ function alg_cache(alg::ABOBA,prob,u,ΔW,ΔZ,p,rate_prototype, noise_rate_protot
107107

108108
half = uEltypeNoUnits(1//2)
109109

110-
if typeof(alg.gamma) <: Number
111-
c₂ = exp.(-alg.gamma*dt)
112-
σ = sqrt.(1 .- alg.scale_noise*c₂.^2)
113-
elseif typeof(alg.gamma) <: AbstractMatrix
110+
if typeof(alg.gamma) <: AbstractMatrix
114111
c₂ = exp(-alg.gamma*dt)
115112
σ = cholesky(I - alg.scale_noise*c₂*transpose(c₂)).U
116113
else
@@ -120,64 +117,64 @@ function alg_cache(alg::ABOBA,prob,u,ΔW,ΔZ,p,rate_prototype, noise_rate_protot
120117

121118
tmp = zero(u)
122119

123-
ABOBACache(utmp, dutmp, k, gtmp, noise, half, c₂, σ, tmp)
120+
ABOBACache(utmp, dumid, dutmp, dunoise, k, gtmp, noise, half, c₂, σ, tmp)
124121
end
125122

126123

127124

128125

129-
mutable struct OBABOConstantCache{uType,uEltypeNoUnits, uCoeffType} <: StochasticDynamicalEqConstantCache
126+
mutable struct OBABOConstantCache{uType,rateNoiseType, uEltypeNoUnits, uCoeffType, uCoeffMType} <: StochasticDynamicalEqConstantCache
130127
k::uType
128+
sig::rateNoiseType
131129
half::uEltypeNoUnits
132130
c₂::uCoeffType
133-
σ::uCoeffType
131+
σ::uCoeffMType
134132
end
135-
@cache struct OBABOCache{uType,uEltypeNoUnits,rateNoiseType,uCoeffType,uTypeCombined} <: StochasticDynamicalEqMutableCache
133+
134+
@cache struct OBABOCache{uType,uEltypeNoUnits,rateNoiseType,uCoeffType, uCoeffMType,uTypeCombined} <: StochasticDynamicalEqMutableCache
136135
utmp::uType
136+
dumid::uType
137137
dutmp::uType
138+
dunoise::uType
138139
k::uType
139140
gtmp::rateNoiseType
140141
noise::uType
141-
noisetmp::uType
142142
half::uEltypeNoUnits
143143
c₂::uCoeffType
144-
σ::uCoeffType
144+
σ::uCoeffMType
145145
tmp::uTypeCombined
146146
end
147147

148148
function alg_cache(alg::OBABO,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototype,jump_rate_prototype,::Type{uEltypeNoUnits},::Type{uBottomEltypeNoUnits},::Type{tTypeNoUnits},uprev,f,t,dt,::Type{Val{false}}) where {uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits}
149149
k = zero(rate_prototype.x[1])
150+
sig = zero(noise_rate_prototype)
150151
half=uEltypeNoUnits(1//2)
151-
if typeof(alg.gamma) <: Number
152-
c₂ = exp.(-alg.gamma*half*dt)
153-
σ = sqrt.(1 .- alg.scale_noise*c₂.^2)
154-
elseif typeof(alg.gamma) <: AbstractMatrix
152+
153+
if typeof(alg.gamma) <: AbstractMatrix
155154
c₂ = exp(-alg.gamma*half*dt)
156155
σ = cholesky(I - alg.scale_noise*c₂*transpose(c₂)).U
157156
else
158157
c₂ = exp.(-alg.gamma*half*dt)
159158
σ = sqrt.(1 .- alg.scale_noise*c₂.^2)
160159
end
161160
# if scale_noise == false, c2 = 1
162-
OBABOConstantCache(k, half, c₂, σ)
161+
OBABOConstantCache(k, sig, half, c₂, σ)
163162
end
164163

165164
function alg_cache(alg::OBABO,prob,u,ΔW,ΔZ,p,rate_prototype, noise_rate_prototype,jump_rate_prototype,::Type{uEltypeNoUnits},::Type{uBottomEltypeNoUnits},::Type{tTypeNoUnits},uprev,f,t,dt,::Type{Val{true}}) where {uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits}
166165
dutmp = zero(u.x[1])
166+
dumid = zero(u.x[1])
167+
dunoise = zero(u.x[1])
167168
utmp = zero(u.x[2])
168169
k = zero(rate_prototype.x[1])
169-
noise = zero(rate_prototype.x[1])
170170

171171
gtmp = zero(noise_rate_prototype)
172-
noisetmp = zero(rate_prototype.x[1])
172+
noise = zero(rate_prototype.x[1])
173173

174174

175175
half = uEltypeNoUnits(1//2)
176176

177-
if typeof(alg.gamma) <: Number
178-
c₂ = exp.(-alg.gamma*half*dt)
179-
σ = sqrt.(1 .- alg.scale_noise*c₂.^2)
180-
elseif typeof(alg.gamma) <: AbstractMatrix
177+
if typeof(alg.gamma) <: AbstractMatrix
181178
c₂ = exp(-alg.gamma*half*dt)
182179
σ = cholesky(I - alg.scale_noise*c₂*transpose(c₂)).U
183180
else
@@ -187,5 +184,5 @@ function alg_cache(alg::OBABO,prob,u,ΔW,ΔZ,p,rate_prototype, noise_rate_protot
187184

188185
tmp = zero(u)
189186

190-
OBABOCache(utmp, dutmp, k, noise, gtmp, noisetmp, half, c₂, σ, tmp)
187+
OBABOCache(utmp, dumid, dutmp, dunoise, k, gtmp, noise, half, c₂, σ, tmp)
191188
end

src/perform_step/dynamical.jl

Lines changed: 97 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ end
4949
du3 = c1.*du2 + c2.*noise
5050
end
5151

52-
5352
# A
5453
u = u2 + half*dt*du3
5554

@@ -62,25 +61,25 @@ end
6261

6362
@muladd function perform_step!(integrator,cache::BAOABCache)
6463
@unpack t,dt,sqdt,uprev,u,p,W,f = integrator
65-
@unpack utmp, dutmp, k, gtmp, noise, half, c1, c2 = cache
64+
@unpack utmp, dumid, dutmp, dunoise, k, gtmp, noise, half, c1, c2 = cache
6665
du1 = uprev.x[1]
6766
u1 = uprev.x[2]
6867

6968
# B
70-
@.. dutmp = du1 + half*dt*k
69+
@.. dumid = du1 + half*dt*k
7170

7271
# A
73-
@.. utmp = u1 + half*dt*dutmp
72+
@.. utmp = u1 + half*dt*dumid
7473

7574
# O
7675
integrator.g(gtmp,utmp,p,t+dt*half)
7776
@.. noise = gtmp*W.dW / sqdt
7877
if typeof(c2) <: AbstractMatrix
79-
mul!(dutmp,c1,dutmp)
80-
mul!(noise,c2,noise)
81-
@.. dutmp+= noise
78+
mul!(dutmp,c1,dumid)
79+
mul!(dunoise,c2,noise)
80+
@.. dutmp+= dunoise
8281
else
83-
@.. dutmp = c1*dutmp + c2*noise
82+
@.. dutmp = c1*dumid + c2*noise
8483
end
8584

8685
# A
@@ -119,7 +118,7 @@ end
119118

120119
@muladd function perform_step!(integrator,cache::ABOBACache)
121120
@unpack t,dt,sqdt,uprev,u,p,W,f = integrator
122-
@unpack utmp, dutmp, k, gtmp, noise, half, c₂, σ = cache
121+
@unpack utmp, dumid, dutmp, dunoise, k, gtmp, noise, half, c₂, σ = cache
123122
du1 = uprev.x[1]
124123
u1 = uprev.x[2]
125124

@@ -129,26 +128,26 @@ end
129128

130129
# B
131130
f.f1(k,du1,utmp,p,t+dt)
132-
@.. dutmp = du1 + half*dt*k
131+
@.. dumid = du1 + half*dt*k
133132

134133
# O
135134
integrator.g(gtmp,utmp,p,t+dt*half)
136135
@.. noise = gtmp*W.dW / sqdt
137136

138137
if typeof(σ) <: AbstractMatrix
139-
mul!(dutmp,c₂,dutmp)
140-
mul!(noise,σ,noise)
141-
@.. dutmp+=noise
138+
mul!(dutmp,c₂,dumid)
139+
mul!(dunoise,σ,noise)
140+
@.. dutmp+=dunoise
142141
else
143-
@.. dutmp = c₂*dutmp + σ*noise
142+
@.. dutmp = c₂*dumid + σ*noise
144143
end
145144

146145

147146
# B
148147
@.. u.x[1] = dutmp + half*dt*k
149148

150149
# A: xt+1
151-
@.. u.x[2] = utmp + half*dt*dutmp
150+
@.. u.x[2] = utmp + half*dt*u.x[1]
152151
end
153152

154153

@@ -161,6 +160,7 @@ function initialize!(integrator, cache::OBABOConstantCache)
161160

162161
verify_f2(integrator.f.f2, du1, u1, p, t, integrator, cache)
163162
cache.k = integrator.f.f1(du1,u1,p,t)
163+
cache.sig = integrator.g(u1,p,t)
164164
end
165165

166166
function initialize!(integrator, cache::OBABOCache)
@@ -170,6 +170,87 @@ function initialize!(integrator, cache::OBABOCache)
170170

171171
verify_f2(integrator.f.f2, cache.k, du1, u1, p, t, integrator, cache)
172172
integrator.f.f1(cache.k,du1,u1,p,t)
173-
174173
integrator.g(cache.gtmp,u1,p,t)
175174
end
175+
176+
177+
@muladd function perform_step!(integrator,cache::OBABOConstantCache)
178+
@unpack t,dt,sqdt,uprev,u,p,W,f = integrator
179+
@unpack half, c₂, σ = cache
180+
du1 = uprev.x[1]
181+
u1 = uprev.x[2]
182+
183+
# O
184+
noise = cache.sig.*W.dW ./ sqdt
185+
if typeof(σ) <: AbstractMatrix || typeof(noise) <: Number
186+
du2 = c₂*du1 + σ*noise
187+
else
188+
du2 = c₂.*du1 + σ.*noise
189+
end
190+
191+
# B
192+
dumid = du2 + half*dt*cache.k
193+
194+
# A
195+
u = u1 + dt*dumid
196+
197+
cache.k = f.f1(dumid,u,p,t+dt)
198+
# B
199+
du3 = dumid + half*dt*cache.k
200+
201+
# O
202+
cache.sig = integrator.g(u,p,t+dt)
203+
noise = cache.sig.*W.dW ./ sqdt # That should be a second noise
204+
if typeof(σ) <: AbstractMatrix || typeof(noise) <: Number
205+
du = c₂*du3 + σ*noise
206+
else
207+
du = c₂.*du3 + σ.*noise
208+
end
209+
210+
integrator.u = ArrayPartition((du, u))
211+
end
212+
213+
214+
@muladd function perform_step!(integrator,cache::OBABOCache)
215+
@unpack t,dt,sqdt,uprev,u,p,W,f = integrator
216+
@unpack utmp, dumid, dutmp, dunoise, k, gtmp, noise, half, c₂, σ = cache
217+
du1 = uprev.x[1]
218+
u1 = uprev.x[2]
219+
220+
# O
221+
@.. noise = gtmp*W.dW / sqdt
222+
223+
if typeof(σ) <: AbstractMatrix
224+
mul!(dutmp,c₂,du1)
225+
mul!(dunoise,σ,noise)
226+
@.. dutmp+=dunoise
227+
else
228+
@.. dutmp = c₂*du1 + σ*noise
229+
end
230+
231+
# B
232+
233+
@.. dumid = dutmp + half*dt*k
234+
235+
# A: xt+1
236+
@.. u.x[2] = u1 + dt*dumid
237+
238+
239+
# B
240+
f.f1(k,dumid,u.x[2],p,t+dt)
241+
@.. dutmp = dumid + half*dt*k
242+
243+
# O
244+
integrator.g(gtmp,u.x[2],p,t+dt)
245+
@.. noise = gtmp*W.dW / sqdt # That should be a second noise
246+
247+
if typeof(σ) <: AbstractMatrix
248+
mul!(u.x[1],c₂,dutmp)
249+
mul!(dunoise,σ,noise)
250+
@.. u.x[1]+=dunoise
251+
else
252+
@.. u.x[1] = c₂*dutmp + σ*noise
253+
end
254+
255+
256+
end

0 commit comments

Comments
 (0)