Skip to content

Commit 6b63618

Browse files
Merge pull request #492 from rmsrosa/rode_solver
RODE solver RandomHeun()
2 parents 41400cc + cb575ce commit 6b63618

File tree

6 files changed

+69
-7
lines changed

6 files changed

+69
-7
lines changed

src/StochasticDiffEq.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,8 @@ using DocStringExtensions
166166
StochasticDiffEqRODECompositeAlgorithm
167167

168168
export RandomEM
169+
170+
export RandomHeun
169171

170172
export IteratedIntegralApprox, IICommutative, IILevyArea
171173

src/alg_utils.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ alg_order(alg::IIF1Mil) = 1 // 1
5959
alg_order(alg::EulerHeun) = 1 // 2
6060
alg_order(alg::LambaEulerHeun) = 1 // 2
6161
alg_order(alg::RandomEM) = 1 // 2
62+
alg_order(alg::RandomHeun) = 1 // 2
6263
alg_order(alg::SimplifiedEM) = 1 // 2
6364
alg_order(alg::RKMil) = 1 // 1
6465
alg_order(alg::RKMilCommute) = 1 // 1

src/algorithms.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -848,6 +848,8 @@ end
848848

849849
struct RandomEM <: StochasticDiffEqRODEAlgorithm end
850850

851+
struct RandomHeun <: StochasticDiffEqRODEAlgorithm end
852+
851853
const SplitSDEAlgorithms = Union{IIF1M,IIF2M,IIF1Mil,SKenCarp,SplitEM}
852854

853855
@doc raw"""

src/caches/basic_method_caches.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,23 @@ function alg_cache(alg::RandomEM,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prot
7272
RandomEMCache(u,uprev,tmp,rtmp)
7373
end
7474

75+
struct RandomHeunConstantCache <: StochasticDiffEqConstantCache end
76+
@cache struct RandomHeunCache{uType,rateType,randType} <: StochasticDiffEqMutableCache
77+
u::uType
78+
uprev::uType
79+
tmp::uType
80+
rtmp1::rateType
81+
rtmp2::rateType
82+
wtmp::randType
83+
end
84+
85+
alg_cache(alg::RandomHeun,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototype,jump_rate_prototype,::Type{uEltypeNoUnits},::Type{uBottomEltypeNoUnits},::Type{tTypeNoUnits},uprev,f,t,dt,::Type{Val{false}}) where {uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits} = RandomHeunConstantCache()
86+
87+
function alg_cache(alg::RandomHeun,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototype,jump_rate_prototype,::Type{uEltypeNoUnits},::Type{uBottomEltypeNoUnits},::Type{tTypeNoUnits},uprev,f,t,dt,::Type{Val{true}}) where {uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits}
88+
tmp = zero(u); rtmp1 = zero(rate_prototype); rtmp2 = zero(rate_prototype); wtmp = zero(ΔW)
89+
RandomHeunCache(u,uprev,tmp,rtmp1,rtmp2,wtmp)
90+
end
91+
7592
struct SimplifiedEMConstantCache <: StochasticDiffEqConstantCache end
7693
@cache struct SimplifiedEMCache{randType,uType,rateType,rateNoiseType} <: StochasticDiffEqMutableCache
7794
u::uType

src/perform_step/low_order.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,29 @@ end
121121
@.. u = uprev + dt * rtmp
122122
end
123123

124+
@muladd function perform_step!(integrator,cache::RandomHeunConstantCache)
125+
@unpack t,dt,uprev,u,W,p,f = integrator
126+
ftmp = integrator.f(uprev,p,t,W.curW)
127+
tmp = @.. uprev + dt * ftmp
128+
wtmp = @.. W.curW + W.dW
129+
u = uprev .+ (dt/2) .* (ftmp .+ integrator.f(tmp,p,t+dt, wtmp))
130+
integrator.u = u
131+
end
132+
133+
@muladd function perform_step!(integrator,cache::RandomHeunCache)
134+
@unpack tmp, rtmp1, rtmp2, wtmp = cache
135+
@unpack t,dt,uprev,u,W,p,f = integrator
136+
integrator.f(rtmp1,uprev,p,t,W.curW)
137+
@.. tmp = uprev + dt * rtmp1
138+
if W.dW isa Number
139+
wtmp = W.curW + W.dW
140+
else
141+
@.. wtmp = W.curW + W.dW
142+
end
143+
integrator.f(rtmp2,tmp,p,t+dt,wtmp)
144+
@.. u = uprev + (dt/2) * (rtmp1 + rtmp2)
145+
end
146+
124147
# weak approximation EM
125148
@muladd function perform_step!(integrator,cache::SimplifiedEMConstantCache)
126149
@unpack t,dt,uprev,u,W,p,f = integrator

test/rode_linear_tests.jl

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,44 @@
1-
using StochasticDiffEq
1+
using StochasticDiffEq, DiffEqNoiseProcess, Random, Test
2+
Random.seed!(100)
23

34
f(u,p,t,W) = 1.01u.+0.87u.*W
45
u0 = 1.00
56
tspan = (0.0,1.0)
67
prob = RODEProblem(f,u0,tspan)
7-
sol = solve(prob,RandomEM(),dt=1/100)
8+
sol = solve(prob,RandomEM(),dt=1/100, save_noise=true)
9+
prob = RODEProblem(f,u0,tspan, noise=NoiseWrapper(sol.W))
10+
sol2 = solve(prob,RandomHeun(),dt=1/100)
11+
@test abs(sol[end] - sol2[end]) < 0.1
12+
813

914
f(du,u,p,t,W) = (du.=1.01u.+0.87u.*W)
1015
u0 = ones(4)
1116
prob = RODEProblem(f,u0,tspan)
12-
sol = solve(prob,RandomEM(),dt=1/100)
17+
sol = solve(prob,RandomEM(),dt=1/100, save_noise=true)
18+
prob = RODEProblem(f,u0,tspan, noise=NoiseWrapper(sol.W))
19+
sol2 = solve(prob,RandomHeun(),dt=1/100)
20+
@test sum(abs,sol[end]-sol2[end]) < 0.1 * sum(abs, sol[end])
1321

1422
f(u,p,t,W) = 2u*sin(W)
1523
u0 = 1.00
1624
tspan = (0.0,5.0)
1725
prob = RODEProblem{false}(f,u0,tspan)
18-
sol = solve(prob,RandomEM(),dt=1/100)
26+
sol = solve(prob,RandomEM(),dt=1/100, save_noise=true)
27+
prob = RODEProblem{false}(f,u0,tspan, noise=NoiseWrapper(sol.W))
28+
sol2 = solve(prob,RandomHeun(),dt=1/100)
29+
@test abs(sol[end]-sol2[end]) < 0.1
1930

2031
function f(du,u,p,t,W)
2132
du[1] = 2u[1]*sin(W[1] - W[2])
2233
du[2] = -2u[2]*cos(W[1] + W[2])
2334
end
2435
u0 = [1.00;1.00]
25-
tspan = (0.0,5.0)
36+
tspan = (0.0,4.0)
2637
prob = RODEProblem(f,u0,tspan)
27-
sol = solve(prob,RandomEM(),dt=1/100)
38+
sol = solve(prob,RandomEM(),dt=1/100, save_noise=true)
39+
prob = RODEProblem(f,u0,tspan, noise=NoiseWrapper(sol.W))
40+
sol2 = solve(prob,RandomHeun(),dt=1/100)
41+
@test sum(abs,sol[end]-sol2[end]) < 0.1 * sum(abs, sol[end])
2842

2943
function f(du,u,p,t,W)
3044
du[1] = -2W[3]*u[1]*sin(W[1] - W[2])
@@ -33,4 +47,7 @@ end
3347
u0 = [1.00;1.00]
3448
tspan = (0.0,5.0)
3549
prob = RODEProblem(f,u0,tspan,rand_prototype=zeros(3))
36-
sol = solve(prob,RandomEM(),dt=1/100)
50+
sol = solve(prob,RandomEM(),dt=1/100, save_noise=true)
51+
prob = RODEProblem(f,u0,tspan, noise=NoiseWrapper(sol.W))
52+
sol2 = solve(prob,RandomHeun(),dt=1/100)
53+
@test sum(abs,sol[end]-sol2[end]) < 0.1 * sum(abs, sol[end])

0 commit comments

Comments
 (0)