Skip to content

Commit d949ca1

Browse files
committed
remove alloc in perform_step for nonscalar noise
1 parent cf5c12b commit d949ca1

File tree

2 files changed

+14
-7
lines changed

2 files changed

+14
-7
lines changed

src/caches/basic_method_caches.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,19 +73,20 @@ function alg_cache(alg::RandomEM,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prot
7373
end
7474

7575
struct RandomHeunConstantCache <: StochasticDiffEqConstantCache end
76-
@cache struct RandomHeunCache{uType,rateType} <: StochasticDiffEqMutableCache
76+
@cache struct RandomHeunCache{uType,rateType,randType} <: StochasticDiffEqMutableCache
7777
u::uType
7878
uprev::uType
7979
tmp::uType
8080
rtmp1::rateType
8181
rtmp2::rateType
82+
wtmp::randType
8283
end
8384

8485
alg_cache(alg::RandomHeun,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototype,jump_rate_prototype,::Type{uEltypeNoUnits},::Type{uBottomEltypeNoUnits},::Type{tTypeNoUnits},uprev,f,t,dt,::Type{Val{false}}) where {uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits} = RandomHeunConstantCache()
8586

8687
function alg_cache(alg::RandomHeun,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototype,jump_rate_prototype,::Type{uEltypeNoUnits},::Type{uBottomEltypeNoUnits},::Type{tTypeNoUnits},uprev,f,t,dt,::Type{Val{true}}) where {uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits}
87-
tmp = zero(u); rtmp1 = zero(rate_prototype); rtmp2 = zero(rate_prototype)
88-
RandomHeunCache(u,uprev,tmp,rtmp1,rtmp2)
88+
tmp = zero(u); rtmp1 = zero(rate_prototype); rtmp2 = zero(rate_prototype); wtmp = zero(ΔW)
89+
RandomHeunCache(u,uprev,tmp,rtmp1,rtmp2,wtmp)
8990
end
9091

9192
struct SimplifiedEMConstantCache <: StochasticDiffEqConstantCache end

src/perform_step/low_order.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -124,17 +124,23 @@ end
124124
@muladd function perform_step!(integrator,cache::RandomHeunConstantCache)
125125
@unpack t,dt,uprev,u,W,p,f = integrator
126126
ftmp = integrator.f(uprev,p,t,W.curW)
127-
tmp = @.. uprev + dt * ftmp
128-
u = uprev .+ (dt/2) .* (ftmp .+ integrator.f(tmp,p,t+dt, W.curW .+ W.dW))
127+
tmp = @.. uprev + dt * ftmp
128+
wtmp = @.. W.curW + W.dW
129+
u = uprev .+ (dt/2) .* (ftmp .+ integrator.f(tmp,p,t+dt, wtmp))
129130
integrator.u = u
130131
end
131132

132133
@muladd function perform_step!(integrator,cache::RandomHeunCache)
133-
@unpack tmp, rtmp1, rtmp2 = cache
134+
@unpack tmp, rtmp1, rtmp2, wtmp = cache
134135
@unpack t,dt,uprev,u,W,p,f = integrator
135136
integrator.f(rtmp1,uprev,p,t,W.curW)
136137
@.. tmp = uprev + dt * rtmp1
137-
integrator.f(rtmp2,tmp,p,t+dt,W.curW+W.dW)
138+
if W.dW isa Number
139+
wtmp = W.curW + W.dW
140+
else
141+
@.. wtmp = W.curW + W.dW
142+
end
143+
integrator.f(rtmp2,tmp,p,t+dt,wtmp)
138144
@.. u = uprev + (dt/2) * (rtmp1 + rtmp2)
139145
end
140146

0 commit comments

Comments
 (0)