Skip to content

Commit 6063ca9

Browse files
committed
Fix strong stochastic RK methods and remaining SROCK methods
1 parent 8915570 commit 6063ca9

File tree

4 files changed

+56
-34
lines changed

4 files changed

+56
-34
lines changed

src/caches/SROCK_caches.jl

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,13 @@ function alg_cache(alg::SROCK2,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_protot
7979
uᵢ₋₂ = zero(u)
8080
Gₛ = zero(noise_rate_prototype)
8181
Gₛ₁ = zero(noise_rate_prototype)
82-
WikRange = false .* vec(ΔW)
83-
vec_χ = false .* vec(ΔW)
82+
if typeof(ΔW) <: Union{SArray,Number}
83+
WikRange = copy(ΔW)
84+
vec_χ = copy(ΔW)
85+
else
86+
WikRange = false .* vec(ΔW)
87+
vec_χ = false .* vec(ΔW)
88+
end
8489
tmp = uᵢ₋₂ # these 2 variables are dummied to use same memory
8590
fsalfirst = k
8691
atmp = zero(rate_prototype)
@@ -125,7 +130,11 @@ function alg_cache(alg::SROCKEM,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_proto
125130
else
126131
Gₛ₁ = zero(noise_rate_prototype)
127132
end
128-
WikRange = false .* vec(ΔW)
133+
if typeof(ΔW) <: Union{SArray,Number}
134+
WikRange = copy(ΔW)
135+
else
136+
WikRange = false .* vec(ΔW)
137+
end
129138
tmp = zero(u) # these 3 variables are dummied to use same memory
130139
fsalfirst = k
131140
atmp = zero(rate_prototype)
@@ -164,7 +173,11 @@ function alg_cache(alg::SKSROCK,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_proto
164173
Gₛ = zero(noise_rate_prototype)
165174
tmp = uᵢ₋₂ # Dummmy variables
166175
fsalfirst = k
167-
WikRange = false .* vec(ΔW)
176+
if typeof(ΔW) <: Union{SArray,Number}
177+
WikRange = copy(ΔW)
178+
else
179+
WikRange = false .* vec(ΔW)
180+
end
168181
atmp = zero(rate_prototype)
169182
constantcache = SKSROCKConstantCache{typeof(t)}(u)
170183
SKSROCKCache(u,uprev,uᵢ₋₁,uᵢ₋₂,Gₛ,tmp,k,fsalfirst,WikRange,atmp,constantcache)

src/perform_step/SROCK_perform_step.jl

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ end
198198
σ = (1-α)*1//2 + α*mσ[deg_index]
199199
τ = 1//2*((1-α)^2) + 2*α*(1-α)*mσ[deg_index] +^2)*(mσ[deg_index]*(mσ[deg_index]+mτ[deg_index]))
200200

201-
sqrt_dt = sqrt(dt)
201+
sqrt_dt = sqrt(abs(dt))
202202

203203
μ = recf[start] # here κ = 0
204204
tᵢ = t + α*dt*μ
@@ -258,7 +258,7 @@ end
258258

259259
uₓ = integrator.f(uₓ,p,tₓ)
260260
u += (1//2*dt)*uₓ
261-
uₓ = 1//2 .* Gₛ .* (W.dW .^ 2 .- dt)
261+
uₓ = 1//2 .* Gₛ .* (W.dW .^ 2 .- abs(dt))
262262
uᵢ₋₂ = uᵢ + uₓ
263263
Gₛ₁ = integrator.g(uᵢ₋₂,p,tᵢ)
264264
u += (1//2)*Gₛ₁
@@ -284,7 +284,7 @@ end
284284
u += (1//2)*dt*uₓ
285285
for i in 1:length(W.dW)
286286
WikJ = W.dW[i]; WikJ2 = vec_χ[i]
287-
WikRange = 1//2 .* (W.dW .* WikJ .- (1:length(W.dW) .== i) .* dt) #.- (1:length(W.dW) .> i) .* dt .* vec_χ .+ (1:length(W.dW) .< i) .* dt .* WikJ2)
287+
WikRange = 1//2 .* (W.dW .* WikJ .- (1:length(W.dW) .== i) .* abs(dt)) #.- (1:length(W.dW) .> i) .* dt .* vec_χ .+ (1:length(W.dW) .< i) .* dt .* WikJ2)
288288
uₓ = Gₛ*WikRange
289289
WikRange = 1//2 .* (1:length(W.dW) .== i)
290290
uᵢ₋₂ = uᵢ + uₓ
@@ -330,7 +330,7 @@ end
330330
σ = (1-α)*1//2 + α*mσ[deg_index]
331331
τ = 1//2*((1-α)^2) + 2*α*(1-α)*mσ[deg_index] +^2)*(mσ[deg_index]*(mσ[deg_index]+mτ[deg_index]))
332332

333-
sqrt_dt = sqrt(dt)
333+
sqrt_dt = sqrt(abs(dt))
334334
if gen_prob
335335
vec_χ .= 1//2 .+ oftype(W.dW, rand(W.rng, length(W.dW)))
336336
@.. vec_χ = 2*floor(vec_χ) - 1
@@ -396,7 +396,7 @@ end
396396
integrator.f(k,uₓ,p,tₓ)
397397
@.. u += (1//2)*dt*k
398398

399-
@.. uₓ = Gₛ*((W.dW^2 - dt)/2)
399+
@.. uₓ = Gₛ*((W.dW^2 - abs(dt))/2)
400400
@.. uᵢ₋₂ = uᵢ + uₓ
401401
integrator.g(Gₛ₁,uᵢ₋₂,p,tᵢ)
402402
@.. u += (1//2)*Gₛ₁
@@ -425,7 +425,7 @@ end
425425

426426
for i in 1:length(W.dW)
427427
WikJ = W.dW[i]; WikJ2 = vec_χ[i]
428-
WikRange .= 1//2 .* (W.dW .* WikJ .- (1:length(W.dW) .== i) .* dt )#.+ (1:length(W.dW) .< i) .* dt .* WikJ2 .- (1:length(W.dW) .> i) .* dt .* vec_χ)
428+
WikRange .= 1//2 .* (W.dW .* WikJ .- (1:length(W.dW) .== i) .* abs(dt) )#.+ (1:length(W.dW) .< i) .* dt .* WikJ2 .- (1:length(W.dW) .> i) .* dt .* vec_χ)
429429
mul!(uₓ,Gₛ,WikRange)
430430
@.. uᵢ₋₂ = uᵢ + uₓ
431431
WikRange .= 1//2 .* (1:length(W.dW) .== i)
@@ -517,7 +517,7 @@ end
517517

518518
if integrator.alg.strong_order_1
519519
if (typeof(W.dW) <: Number) || (length(W.dW) == 1) || (is_diagonal_noise(integrator.sol.prob))
520-
uᵢ₋₂ = @. 1//2 * Gₛ * (W.dW ^ 2 - dt)
520+
uᵢ₋₂ = @. 1//2 * Gₛ * (W.dW ^ 2 - abs(dt))
521521
tmp = @. u + uᵢ₋₂
522522
Gₛ = integrator.g(tmp,p,tᵢ)
523523
uᵢ₋₁ = @. 1//2*Gₛ
@@ -528,7 +528,7 @@ end
528528
else
529529
for i in 1:length(W.dW)
530530
WikJ = W.dW[i]
531-
WikRange = 1//2 .* (W.dW .* WikJ - (1:length(W.dW) .== i) .* dt)
531+
WikRange = 1//2 .* (W.dW .* WikJ - (1:length(W.dW) .== i) .* abs(dt))
532532
uᵢ₋₂ = Gₛ*WikRange
533533
WikRange = 1//2 .* (1:length(W.dW) .== i)
534534
tmp = u + uᵢ₋₂
@@ -608,7 +608,7 @@ end
608608

609609
if integrator.alg.strong_order_1
610610
if (typeof(W.dW) <: Number) || (length(W.dW) == 1) || (is_diagonal_noise(integrator.sol.prob))
611-
@.. uᵢ₋₂ = 1//2*Gₛ*(W.dW^2 - dt)
611+
@.. uᵢ₋₂ = 1//2*Gₛ*(W.dW^2 - abs(dt))
612612
@.. tmp = u + uᵢ₋₂
613613
integrator.g(Gₛ,tmp,p,tᵢ)
614614
@.. uᵢ₋₁ = 1//2*Gₛ
@@ -619,7 +619,7 @@ end
619619
else
620620
for i in 1:length(W.dW)
621621
WikJ = W.dW[i]
622-
WikRange .= 1//2 .* (WikJ .* W.dW .- (1:length(W.dW) .== i) .* dt)
622+
WikRange .= 1//2 .* (WikJ .* W.dW .- (1:length(W.dW) .== i) .* abs(dt))
623623
mul!(uᵢ₋₂,Gₛ,WikRange)
624624
WikRange .= 1//2 .* (1:length(W.dW) .== i)
625625
@.. tmp = u + uᵢ₋₂
@@ -834,7 +834,7 @@ end
834834

835835
η₁ = (rand() < 1//2) ? -1 : 1
836836
η₂ = (rand() < 1//2) ? -1 : 1
837-
sqrt_dt = sqrt(dt)
837+
sqrt_dt = sqrt(abs(dt))
838838

839839
Û₁ = zero(u)
840840
Û₂ = zero(u)
@@ -901,15 +901,15 @@ end
901901
uₓ += Gₛ*W.dW
902902

903903
uₓ = integrator.f(uₓ,p,tₓ)
904-
u += (1//2*dt)*uₓ + Gₛ*((W.dW^2 - dt)/(η₁*sqrt_dt) - W.dW)
904+
u += (1//2*dt)*uₓ + Gₛ*((W.dW^2 - abs(dt))/(η₁*sqrt_dt) - W.dW)
905905
Û₁ -= (η₁*sqrt_dt/2)*Gₛ
906906
Û₂ += (η₁*sqrt_dt/2)*Gₛ
907907

908908
Gₛ = integrator.g(Û₂,p,t̂₂)
909909
u += Gₛ*W.dW
910910

911911
Gₛ = integrator.g(Û₁,p,t̂₁)
912-
u += Gₛ*(W.dW - (W.dW^2 - dt)/(η₁*sqrt_dt))
912+
u += Gₛ*(W.dW - (W.dW^2 - abs(dt))/(η₁*sqrt_dt))
913913
elseif is_diagonal_noise(integrator.sol.prob)
914914

915915
Gₛ = integrator.g(Û₁,p,t̂₁)
@@ -922,13 +922,13 @@ end
922922
uₓ = integrator.f(uₓ,p,tₓ)
923923
u += (1//2)*dt*uₓ
924924

925-
u .+= Gₛ .* ((W.dW .^ 2 .- dt) ./ (η₁*sqrt_dt) .- W.dW)
925+
u .+= Gₛ .* ((W.dW .^ 2 .- abs(dt)) ./ (η₁*sqrt_dt) .- W.dW)
926926

927927
Gₛ = integrator.g(Û₂,p,t̂₂)
928928
u .+= Gₛ .* W.dW
929929

930930
Gₛ = integrator.g(Û₁,p,t̂₁)
931-
u .-= Gₛ .* ((W.dW .^ 2 .- dt) ./ (η₁*sqrt_dt) .- W.dW)
931+
u .-= Gₛ .* ((W.dW .^ 2 .- abs(dt)) ./ (η₁*sqrt_dt) .- W.dW)
932932
else
933933
Gₛ = integrator.g(Û₁,p,t̂₁)
934934

@@ -945,7 +945,7 @@ end
945945
for i in 1:length(W.dW)
946946
uᵢ₋₁ = Û₁ - (1//2*η₁*sqrt_dt)*@view(Gₛ[:,i])
947947
Gₛ₁ = integrator.g(uᵢ₋₁,p,t̂₁)
948-
u += @view(Gₛ₁[:,i])*W.dW[i] + (@view(Gₛ[:,i]) - @view(Gₛ₁[:,i]))*((W.dW[i]^2 - dt)/(η₁*sqrt_dt))
948+
u += @view(Gₛ₁[:,i])*W.dW[i] + (@view(Gₛ[:,i]) - @view(Gₛ₁[:,i]))*((W.dW[i]^2 - abs(dt))/(η₁*sqrt_dt))
949949
end
950950

951951
for i in 1:length(W.dW)
@@ -988,7 +988,7 @@ end
988988

989989
η₁ = (rand() < 1//2) ? -1 : 1
990990
η₂ = (rand() < 1//2) ? -1 : 1
991-
sqrt_dt = sqrt(dt)
991+
sqrt_dt = sqrt(abs(dt))
992992

993993
@.. Û₁ = zero(eltype(u))
994994
@.. Û₂ = zero(eltype(u))
@@ -1056,15 +1056,15 @@ end
10561056
@.. uₓ += Gₛ*W.dW
10571057

10581058
integrator.f(k,uₓ,p,tₓ)
1059-
@.. u += (1//2*dt)*k + Gₛ*((W.dW^2 - dt)/(η₁*sqrt_dt) - W.dW)
1059+
@.. u += (1//2*dt)*k + Gₛ*((W.dW^2 - abs(dt))/(η₁*sqrt_dt) - W.dW)
10601060
@.. Û₁ -= (η₁*sqrt_dt/2)*Gₛ
10611061
@.. Û₂ += (η₁*sqrt_dt/2)*Gₛ
10621062

10631063
integrator.g(Gₛ,Û₂,p,t̂₂)
10641064
@.. u += Gₛ*W.dW
10651065

10661066
integrator.g(Gₛ,Û₁,p,t̂₁)
1067-
@.. u += Gₛ*(W.dW - (W.dW^2 - dt)/(η₁*sqrt_dt))
1067+
@.. u += Gₛ*(W.dW - (W.dW^2 - abs(dt))/(η₁*sqrt_dt))
10681068
else
10691069
integrator.g(Gₛ,Û₁,p,t̂₁)
10701070

@@ -1081,7 +1081,7 @@ end
10811081
for i in 1:length(W.dW)
10821082
@.. uᵢ₋₁ = Û₁ - (1//2*η₁*sqrt_dt)*@view(Gₛ[:,i])
10831083
integrator.g(Gₛ₁,uᵢ₋₁,p,t̂₁)
1084-
@.. u += @view(Gₛ₁[:,i])*W.dW[i] + (@view(Gₛ[:,i]) - @view(Gₛ₁[:,i]))*((W.dW[i]^2 - dt)/(η₁*sqrt_dt))
1084+
@.. u += @view(Gₛ₁[:,i])*W.dW[i] + (@view(Gₛ[:,i]) - @view(Gₛ₁[:,i]))*((W.dW[i]^2 - abs(dt))/(η₁*sqrt_dt))
10851085
end
10861086

10871087
for i in 1:length(W.dW)
@@ -1120,7 +1120,7 @@ end
11201120
σ = mσ[deg_index]
11211121
τ = mτ[deg_index]
11221122

1123-
sqrt_dt = sqrt(dt)
1123+
sqrt_dt = sqrt(abs(dt))
11241124
(gen_prob) && (vec_χ = 2 .* floor.( 1//2 .+ false .* W.dW .+ rand(length(W.dW))) .- 1)
11251125

11261126
tᵢ₋₂ = t
@@ -1309,7 +1309,7 @@ end
13091309
σ = mσ[deg_index]
13101310
τ = mτ[deg_index]
13111311

1312-
sqrt_dt = sqrt(dt)
1312+
sqrt_dt = sqrt(abs(dt))
13131313
(gen_prob) && (vec_χ .= 2 .* floor.(1//2 .+ false .* vec_χ .+ rand(length(vec_χ))) .- 1)
13141314

13151315
tᵢ₋₂ = t

src/perform_step/sri.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ end
6262

6363
sqrt3 = sqrt(3one(eltype(W.dW)))
6464
if typeof(W.dW) <: Union{SArray,Number}
65-
chi1 = (W.dW.^2 - dt)/2integrator.sqdt #I_(1,1)/sqrt(h)
65+
chi1 = (W.dW.^2 - abs(dt))/2integrator.sqdt #I_(1,1)/sqrt(h)
6666
chi2 = (W.dW + W.dZ/sqrt3)/2 #I_(1,0)/h
6767
chi3 = (W.dW.^3 - 3W.dW*dt)/6dt #I_(1,1,1)/h
6868
else
@@ -169,7 +169,7 @@ end
169169

170170
sqrt3 = sqrt(3one(eltype(W.dW)))
171171
if typeof(W.dW) <: Union{SArray,Number}
172-
chi1 = (W.dW.^2 - dt)/2integrator.sqdt #I_(1,1)/sqrt(h)
172+
chi1 = (W.dW.^2 - abs(dt))/2integrator.sqdt #I_(1,1)/sqrt(h)
173173
chi2 = (W.dW + W.dZ/sqrt3)/2 #I_(1,0)/h
174174
chi3 = (W.dW.^3 - 3W.dW*dt)/6dt #I_(1,1,1)/h
175175
else
@@ -221,7 +221,7 @@ end
221221
@muladd function perform_step!(integrator,cache::SRIW1ConstantCache)
222222
@unpack t,dt,uprev,u,W,p,f = integrator
223223
sqrt3 = sqrt(3one(eltype(W.dW)))
224-
chi1 = @.. (W.dW.^2 - dt)/2integrator.sqdt #I_(1,1)/sqrt(h)
224+
chi1 = @.. (W.dW.^2 - abs(dt))/2integrator.sqdt #I_(1,1)/sqrt(h)
225225
chi2 = @.. (W.dW + W.dZ/sqrt3)/2 #I_(1,0)/h
226226
chi3 = @.. (W.dW.^3 - 3W.dW*dt)/6dt #I_(1,1,1)/h
227227
fH01 = dt*integrator.f(uprev,p,t)
@@ -264,12 +264,12 @@ end
264264
@unpack t,dt,uprev,u,W,p,f = integrator
265265
@unpack c₀,c₁,A₀,A₁,B₀,B₁,α,β₁,β₂,β₃,β₄,stages,H0,H1,error_terms = cache
266266
sqrt3 = sqrt(3one(eltype(W.dW)))
267-
chi1 = .5*(W.dW.^2 - dt)/integrator.sqdt #I_(1,1)/sqrt(h)
267+
chi1 = .5*(W.dW.^2 - abs(dt))/integrator.sqdt #I_(1,1)/sqrt(h)
268268
chi2 = .5*(W.dW + W.dZ/sqrt3) #I_(1,0)/h
269269
chi3 = 1/6 * (W.dW.^3 - 3*W.dW*dt)/dt #I_(1,1,1)/h
270270

271-
fill!(H0,zero(typeof(u)))
272-
fill!(H1,zero(typeof(u)))
271+
fill!(H0,zero(u))
272+
fill!(H1,zero(u))
273273
@inbounds for i in 1:stages
274274
A0temp = zero(u)
275275
B0temp = zero(u)
@@ -321,7 +321,7 @@ end
321321
@unpack a021,a031,a032,a041,a042,a043,a121,a131,a132,a141,a142,a143,b021,b031,b032,b041,b042,b043,b121,b131,b132,b141,b142,b143,c02,c03,c04,c11,c12,c13,c14,α1,α2,α3,α4,beta11,beta12,beta13,beta14,beta21,beta22,beta23,beta24,beta31,beta32,beta33,beta34,beta41,beta42,beta43,beta44 = cache
322322

323323
sqrt3 = sqrt(3one(eltype(W.dW)))
324-
chi1 = (W.dW.^2 .- dt)/(2integrator.sqdt) #I_(1,1)/sqrt(h)
324+
chi1 = (W.dW.^2 .- abs(dt))/(2integrator.sqdt) #I_(1,1)/sqrt(h)
325325
chi2 = (W.dW .+ W.dZ./sqrt3)./2 #I_(1,0)/h
326326
chi3 = (W.dW.^3 .- 3*W.dW*dt)/(6dt) #I_(1,1,1)/h
327327

@@ -378,7 +378,7 @@ end
378378
sqdt = integrator.sqdt
379379
sqrt3 = sqrt(3one(eltype(W.dW)))
380380
if typeof(W.dW) <: Union{SArray,Number}
381-
chi1 = (W.dW.^2 - dt)/2integrator.sqdt #I_(1,1)/sqrt(h)
381+
chi1 = (W.dW.^2 - abs(dt))/2integrator.sqdt #I_(1,1)/sqrt(h)
382382
chi2 = (W.dW + W.dZ/sqrt3)/2 #I_(1,0)/h
383383
chi3 = (W.dW.^3 - 3W.dW*dt)/6dt #I_(1,1,1)/h
384384
else

test/reversal_tests.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,17 @@ Ito_solver = [
4343
WangLi3SMil_E(),
4444
WangLi3SMil_F(),
4545
RKMil(),
46+
SRI(),
47+
SRIW1(),
48+
SRIW2(),
49+
SOSRI(),
50+
SOSRI2(),
4651
# S-Rock methods
4752
SROCK1(),
53+
SROCKEM(),
54+
SROCK2(),
55+
SKSROCK(),
56+
TangXiaoSROCK2(),
4857
# stiff methods
4958
ImplicitEM(),
5059
ImplicitRKMil(),

0 commit comments

Comments
 (0)