Skip to content

Commit 801861d

Browse files
Merge pull request #496 from rmsrosa/noise_hygiene
deepcopy prob.noise
2 parents b34540b + 1ef9baf commit 801861d

File tree

5 files changed

+48
-4
lines changed

5 files changed

+48
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
3434
Adapt = "3"
3535
ArrayInterface = "2.4, 3.0, 4, 5, 6"
3636
DataStructures = "0.18"
37-
DiffEqBase = "6.19"
37+
DiffEqBase = "6.104"
3838
JumpProcesses = "9"
3939
DiffEqNoiseProcess = "5.13"
4040
DocStringExtensions = "0.8, 0.9"

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ functionality should check out [DifferentialEquations.jl](https://github.com/Sci
1111

1212
## API
1313

14-
StochasticDiffEq.jl is part of the JuliaDiffEq common interface, but can be used independently of DifferentialEquations.jl. The only requirement is that the user passes an StochasticDiffEq.jl algorithm to `solve`. For example, we can solve the [SDE tutorial from the docs](https://diffeq.sciml.ai/stable/tutorials/sde_example/) using the `SRIW1()` algorithm:
14+
StochasticDiffEq.jl is part of the JuliaDiffEq common interface, but can be used independently of [DifferentialEquations.jl](https://github.com/SciML/DifferentialEquations.jl). The only requirement is that the user passes an StochasticDiffEq.jl algorithm to `solve`. For example, we can solve the [SDE tutorial from the docs](https://diffeq.sciml.ai/stable/tutorials/sde_example/) using the `SRIW1()` algorithm:
1515

1616
```julia
1717
using StochasticDiffEq
@@ -66,7 +66,7 @@ end
6666
prob = SDEProblem(f,g,ones(2),(0.0,1.0),noise_rate_prototype=zeros(2,4))
6767
```
6868

69-
Colored noise can be set using [an `AbstractNoiseProcess`](https://diffeq.sciml.ai/stable/features/noise_process/). For example, we can set the underlying noise process to a `GeometricBrownian` via:
69+
Colored noise can be set using [an `AbstractNoiseProcess`](https://diffeq.sciml.ai/stable/features/noise_process/). For example, we can set the underlying noise process to a `GeometricBrownianMotionProcess` via:
7070

7171
```julia
7272
μ = 1.0

src/solve.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,12 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractRODEProblem,
55
kwargs...) where recompile_flag
66
integrator = DiffEqBase.__init(prob,alg,timeseries,ts,recompile;kwargs...)
77
solve!(integrator)
8+
if typeof(prob) <: DiffEqBase.AbstractRODEProblem && typeof(prob.noise) == typeof(integrator.sol.W) && (!haskey(kwargs, :alias_noise) || kwargs[:alias_noise] === true)
9+
# would be better to make the following a function `noise_deepcopy!(W::T, Z::T) where {T <: AbstractNoiseProcess}` in `DiffEqNoiseProcess.jl` or a proper `copy` overload, but this should do it for the moment
10+
for x in fieldnames(typeof(prob.noise))
11+
setfield!(prob.noise, x, deepcopy(getfield(integrator.sol.W, x)))
12+
end
13+
end
814
integrator.sol
915
end
1016

@@ -412,7 +418,7 @@ function DiffEqBase.__init(
412418
=#
413419
end
414420
elseif typeof(prob) <: DiffEqBase.AbstractRODEProblem
415-
W = prob.noise
421+
W = (!haskey(kwargs, :alias_noise) || kwargs[:alias_noise] === true) ? deepcopy(prob.noise) : prob.noise
416422
if W.reset
417423
# Reseed
418424
if typeof(W) <: Union{NoiseProcess, NoiseTransport} && W.reseed

test/noise_type_test.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,26 @@ sol = solve(prob,SRA())
3737

3838
@test length(sol.W[1]) == 4
3939

40+
f(du,u,p,t) = (du.=1.01u)
41+
g(du,u,p,t) = (du.=0.1)
42+
Z = WienerProcess(0.0, [0.0])
43+
prob = SDEProblem(f,g,[1.0],(0.0,1.0),noise=Z)
44+
45+
sol = solve(prob,EM(),dt=1/100)
46+
47+
@test sol.W == prob.noise
48+
@test objectid(prob.noise) != objectid(sol.W)
49+
50+
sol = solve(prob,EM(),dt=1/1000,alias_noise=false)
51+
52+
@test sol.W == prob.noise
53+
@test objectid(prob.noise) == objectid(sol.W)
54+
55+
sol = solve(prob,EM(),dt=1/1000, alias_noise=true)
56+
57+
@test sol.W == prob.noise
58+
@test objectid(prob.noise) != objectid(sol.W)
59+
4060
function g(du,u,p,t)
4161
@test typeof(du) <: SparseMatrixCSC
4262
du[1,1] = 0.3u[1]
@@ -72,3 +92,18 @@ tspan = (0.0,2.0)
7292
prob = SDEProblem(drift,vol,u0,tspan, noise=W)
7393
sol = solve(prob,EM(),dt=0.01)
7494
@test sol.W.curt last(tspan)
95+
96+
@test typeof(sol.W) == typeof(prob.noise) <: NoiseFunction
97+
@test objectid(prob.noise) != objectid(sol.W)
98+
99+
sol = solve(prob,EM(),dt=0.01,alias_noise=true)
100+
@test sol.W.curt last(tspan)
101+
102+
@test typeof(sol.W) == typeof(prob.noise) <: NoiseFunction
103+
@test objectid(prob.noise) != objectid(sol.W)
104+
105+
sol = solve(prob,EM(),dt=0.01,alias_noise=false)
106+
@test sol.W.curt last(tspan)
107+
108+
@test typeof(sol.W) == typeof(prob.noise) <: NoiseFunction
109+
@test objectid(prob.noise) == objectid(sol.W)

test/reversal_tests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ W_reverse = reverse(W_forward)
8585

8686
@test sol_forward(ts).u sol_reverse(ts).u rtol = 1e-3
8787
@test length(sol_forward.t) == length(sol_reverse.t)
88+
GC.gc()
8889
end
8990
end
9091

@@ -103,6 +104,7 @@ end
103104

104105
@test sol_forward(ts).u sol_reverse(ts).u rtol = 1e-2
105106
@test length(sol_forward.t) == length(sol_reverse.t)
107+
GC.gc()
106108
end
107109
end
108110

@@ -138,5 +140,6 @@ end
138140
@test sol_forward(ts).u sol_reverse(ts).u rtol = 1e-2
139141
end
140142
@test length(sol_forward.t) == length(sol_reverse.t)
143+
GC.gc()
141144
end
142145
end

0 commit comments

Comments
 (0)