Skip to content

Commit 914a509

Browse files
Merge pull request #501 from rmsrosa/random_tamed_em
add RandomTamedEM() solver for RODEs
2 parents 7862ea0 + 75653fc commit 914a509

File tree

6 files changed

+53
-13
lines changed

6 files changed

+53
-13
lines changed

src/StochasticDiffEq.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -165,9 +165,7 @@ using DocStringExtensions
165165
export StochasticDiffEqRODEAlgorithm, StochasticDiffEqRODEAdaptiveAlgorithm,
166166
StochasticDiffEqRODECompositeAlgorithm
167167

168-
export RandomEM
169-
170-
export RandomHeun
168+
export RandomEM, RandomTamedEM, RandomHeun
171169

172170
export IteratedIntegralApprox, IICommutative, IILevyArea
173171

src/alg_utils.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ alg_order(alg::EulerHeun) = 1 // 2
6060
alg_order(alg::LambaEulerHeun) = 1 // 2
6161
alg_order(alg::RandomEM) = 1 // 2
6262
alg_order(alg::RandomHeun) = 1 // 2
63+
alg_order(alg::RandomTamedEM) = 1 // 2
6364
alg_order(alg::SimplifiedEM) = 1 // 2
6465
alg_order(alg::RKMil) = 1 // 1
6566
alg_order(alg::RKMilCommute) = 1 // 1

src/algorithms.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -850,6 +850,8 @@ struct RandomEM <: StochasticDiffEqRODEAlgorithm end
850850

851851
struct RandomHeun <: StochasticDiffEqRODEAlgorithm end
852852

853+
struct RandomTamedEM <: StochasticDiffEqRODEAlgorithm end
854+
853855
const SplitSDEAlgorithms = Union{IIF1M,IIF2M,IIF1Mil,SKenCarp,SplitEM}
854856

855857
@doc raw"""

src/caches/basic_method_caches.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,22 @@ function alg_cache(alg::RandomEM,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prot
7272
RandomEMCache(u,uprev,tmp,rtmp)
7373
end
7474

75+
struct RandomTamedEMConstantCache <: StochasticDiffEqConstantCache end
76+
77+
@cache struct RandomTamedEMCache{uType,rateType} <: StochasticDiffEqMutableCache
78+
u::uType
79+
uprev::uType
80+
tmp::uType
81+
rtmp::rateType
82+
end
83+
84+
alg_cache(alg::RandomTamedEM,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototype,jump_rate_prototype,::Type{uEltypeNoUnits},::Type{uBottomEltypeNoUnits},::Type{tTypeNoUnits},uprev,f,t,dt,::Type{Val{false}}) where {uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits} = RandomTamedEMConstantCache()
85+
86+
function alg_cache(alg::RandomTamedEM,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototype,jump_rate_prototype,::Type{uEltypeNoUnits},::Type{uBottomEltypeNoUnits},::Type{tTypeNoUnits},uprev,f,t,dt,::Type{Val{true}}) where {uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits}
87+
tmp = zero(u); rtmp = zero(rate_prototype)
88+
RandomTamedEMCache(u,uprev,tmp,rtmp)
89+
end
90+
7591
struct RandomHeunConstantCache <: StochasticDiffEqConstantCache end
7692
@cache struct RandomHeunCache{uType,rateType,randType} <: StochasticDiffEqMutableCache
7793
u::uType

src/perform_step/low_order.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,20 @@ end
121121
@.. u = uprev + dt * rtmp
122122
end
123123

124+
@muladd function perform_step!(integrator,cache::RandomTamedEMConstantCache)
125+
@unpack t,dt,uprev,u,W,p,f = integrator
126+
ftmp = integrator.f(uprev,p,t,W.curW)
127+
u = uprev .+ dt .* ftmp ./ (1 .+ dt .* norm(ftmp))
128+
integrator.u = u
129+
end
130+
131+
@muladd function perform_step!(integrator,cache::RandomTamedEMCache)
132+
@unpack rtmp = cache
133+
@unpack t,dt,uprev,u,W,p,f = integrator
134+
integrator.f(rtmp,uprev,p,t,W.curW)
135+
@.. u = uprev + dt * rtmp / (1 + dt * norm(rtmp))
136+
end
137+
124138
@muladd function perform_step!(integrator,cache::RandomHeunConstantCache)
125139
@unpack t,dt,uprev,u,W,p,f = integrator
126140
ftmp = integrator.f(uprev,p,t,W.curW)

test/rode_linear_tests.jl

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,9 @@ prob = RODEProblem(f,u0,tspan)
88
sol = solve(prob,RandomEM(),dt=1/100, save_noise=true)
99
prob = RODEProblem(f,u0,tspan, noise=NoiseWrapper(sol.W))
1010
sol2 = solve(prob,RandomHeun(),dt=1/100)
11-
@test abs(sol[end] - sol2[end]) < 0.1
12-
11+
@test abs(sol[end] - sol2[end]) < 0.1 * abs(sol[end])
12+
sol3 = solve(prob,RandomTamedEM(),dt=1/100)
13+
@test abs(sol[end] - sol3[end]) < 0.1 * abs(sol[end])
1314

1415
f(du,u,p,t,W) = (du.=1.01u.+0.87u.*W)
1516
u0 = ones(4)
@@ -18,36 +19,44 @@ sol = solve(prob,RandomEM(),dt=1/100, save_noise=true)
1819
prob = RODEProblem(f,u0,tspan, noise=NoiseWrapper(sol.W))
1920
sol2 = solve(prob,RandomHeun(),dt=1/100)
2021
@test sum(abs,sol[end]-sol2[end]) < 0.1 * sum(abs, sol[end])
22+
sol3 = solve(prob,RandomEM(),dt=1/100)
23+
@test sum(abs,sol[end]-sol3[end]) < 0.1 * sum(abs, sol[end])
2124

2225
f(u,p,t,W) = 2u*sin(W)
2326
u0 = 1.00
24-
tspan = (0.0,5.0)
27+
tspan = (0.0,1.0)
2528
prob = RODEProblem{false}(f,u0,tspan)
2629
sol = solve(prob,RandomEM(),dt=1/100, save_noise=true)
2730
prob = RODEProblem{false}(f,u0,tspan, noise=NoiseWrapper(sol.W))
2831
sol2 = solve(prob,RandomHeun(),dt=1/100)
29-
@test abs(sol[end]-sol2[end]) < 0.1
32+
@test abs(sol[end]-sol2[end]) < 0.1 * abs(sol[end])
33+
sol3 = solve(prob,RandomTamedEM(),dt=1/100)
34+
@test abs(sol[end]-sol3[end]) < 0.1 * abs(sol[end])
3035

3136
function f(du,u,p,t,W)
32-
du[1] = 2u[1]*sin(W[1] - W[2])
33-
du[2] = -2u[2]*cos(W[1] + W[2])
37+
du[1] = 0.2u[1]*sin(W[1] - W[2])
38+
du[2] = -0.2u[2]*cos(W[1] + W[2])
3439
end
3540
u0 = [1.00;1.00]
36-
tspan = (0.0,4.0)
41+
tspan = (0.0,1.0)
3742
prob = RODEProblem(f,u0,tspan)
3843
sol = solve(prob,RandomEM(),dt=1/100, save_noise=true)
3944
prob = RODEProblem(f,u0,tspan, noise=NoiseWrapper(sol.W))
4045
sol2 = solve(prob,RandomHeun(),dt=1/100)
4146
@test sum(abs,sol[end]-sol2[end]) < 0.1 * sum(abs, sol[end])
47+
sol3 = solve(prob,RandomTamedEM(),dt=1/100)
48+
@test sum(abs,sol[end]-sol3[end]) < 0.1 * sum(abs, sol[end])
4249

4350
function f(du,u,p,t,W)
44-
du[1] = -2W[3]*u[1]*sin(W[1] - W[2])
45-
du[2] = -2u[2]*cos(W[1] + W[2])
51+
du[1] = -0.2W[3]*u[1]*sin(W[1] - W[2])
52+
du[2] = -0.2u[2]*cos(W[1] + W[2])
4653
end
4754
u0 = [1.00;1.00]
48-
tspan = (0.0,5.0)
55+
tspan = (0.0,1.0)
4956
prob = RODEProblem(f,u0,tspan,rand_prototype=zeros(3))
5057
sol = solve(prob,RandomEM(),dt=1/100, save_noise=true)
5158
prob = RODEProblem(f,u0,tspan, noise=NoiseWrapper(sol.W))
5259
sol2 = solve(prob,RandomHeun(),dt=1/100)
5360
@test sum(abs,sol[end]-sol2[end]) < 0.1 * sum(abs, sol[end])
61+
sol3 = solve(prob,RandomTamedEM(),dt=1/100)
62+
@test sum(abs,sol[end]-sol3[end]) < 0.1 * sum(abs, sol3[end])

0 commit comments

Comments
 (0)