Skip to content

Commit 23754a8

Browse files
Add more methods for EnsembleProblem and add more tests
1 parent 8c17212 commit 23754a8

File tree

5 files changed

+61
-0
lines changed

5 files changed

+61
-0
lines changed

src/ensemble/ensemble_problems.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,26 @@ function EnsembleProblem(; prob,
4444
EnsembleProblem(prob; prob_func, output_func, reduction, u_init, safetycopy)
4545
end
4646

47+
function EnsembleProblem(; prob,
48+
u0s::Union{Nothing, Vector{uType}} = nothing,
49+
prob_func = (prob, i, repeat) -> remake(prob, u0 = u0s[i]),
50+
output_func = DEFAULT_OUTPUT_FUNC,
51+
reduction = DEFAULT_REDUCTION,
52+
u_init = nothing, p = nothing,
53+
safetycopy = prob_func !== DEFAULT_PROB_FUNC) where {uType}
54+
EnsembleProblem(prob; prob_func, output_func, reduction, u_init, safetycopy)
55+
end
56+
57+
function EnsembleProblem(; prob,
58+
trajectories::Int,
59+
prob_func,
60+
output_func = DEFAULT_OUTPUT_FUNC,
61+
reduction = DEFAULT_REDUCTION,
62+
u_init = nothing, p = nothing,
63+
safetycopy = prob_func !== DEFAULT_PROB_FUNC)
64+
EnsembleProblem(prob; prob_func, output_func, reduction, u_init, safetycopy)
65+
end
66+
4767
struct WeightedEnsembleProblem{T1 <: AbstractEnsembleProblem, T2 <: AbstractVector} <:
4868
AbstractEnsembleProblem
4969
ensembleprob::T1

test/downstream/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
[deps]
22
BoundaryValueDiffEq = "764a87c0-6b3e-53db-9096-fe964310641d"
3+
DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
34
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
45
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
56
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"

test/downstream/ensemble_diffeq.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
using DifferentialEquations
2+
3+
f(u, p, t) = 1.01 * u
4+
u0 = 1 / 2
5+
tspan = (0.0, 1.0)
6+
prob = ODEProblem(f, u0, tspan)
7+
ensemble_prob = EnsembleProblem(prob, prob_func = (prob, i, repeat) -> remake(prob, u0 = rand()))
8+
sim = solve(ensemble_prob, EnsembleThreads(), trajectories = 10, dt = 0.1)

test/downstream/ensemble_nondes.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
using Optimization, OptimizationOptimJL, ForwardDiff, Test
2+
3+
x0 = zeros(2)
4+
rosenbrock(x, p = nothing) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2
5+
l1 = rosenbrock(x0)
6+
7+
optf = OptimizationFunction(rosenbrock, Optimization.AutoForwardDiff())
8+
prob = OptimizationProblem(optf, x0)
9+
sol1 = Optimization.solve(prob, OptimizationOptimJL.BFGS(), maxiters = 5)
10+
11+
ensembleprob = Optimization.EnsembleProblem(prob, [x0, x0 .+ rand(2), x0 .+ rand(2), x0 .+ rand(2)])
12+
13+
sol = Optimization.solve(ensembleprob, OptimizationOptimJL.BFGS(), EnsembleThreads(), trajectories = 4, maxiters = 5)
14+
@test findmin(i -> sol[i].objective, 1:4)[1] < sol1.objective
15+
16+
sol = Optimization.solve(ensembleprob, OptimizationOptimJL.BFGS(), EnsembleDistributed(), trajectories = 4, maxiters = 5)
17+
@test findmin(i -> sol[i].objective, 1:4)[1] < sol1.objective
18+
19+
prob = OptimizationProblem(optf, x0, lb = [-0.5, -0.5], ub = [0.5, 0.5])
20+
ensembleprob = Optimization.EnsembleProblem(prob, 5, prob_func = (prob, i, repeat) -> remake(prob, u0 = rand(-0.5:0.001:0.5, 2)))
21+
22+
sol = Optimization.solve(ensembleprob, OptimizationOptimJL.BFGS(), EnsembleThreads(), trajectories = 5, maxiters = 5)
23+
@test findmin(i -> sol[i].objective, 1:4)[1] < sol1.objective
24+
25+
sol = Optimization.solve(ensembleprob, OptimizationOptimJL.BFGS(), EnsembleDistributed(), trajectories = 5, maxiters = 5)
26+
@test findmin(i -> sol[i].objective, 1:4)[1] < sol1.objective

test/runtests.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,12 @@ end
7878
@time @safetestset "Ensemble solution statistics" begin
7979
include("downstream/ensemble_stats.jl")
8080
end
81+
@time @safetestset "Ensemble Optimization and Nonlinear problems" begin
82+
include("downstream/ensemble_nondes.jl")
83+
end
84+
@time @safetestset "Ensemble with DifferentialEquations automatic algorithm selection" begin
85+
include("downstream/ensemble_diffeq.jl")
86+
end
8187
@time @safetestset "Symbol and integer based indexing of interpolated solutions" begin
8288
include("downstream/symbol_indexing.jl")
8389
end

0 commit comments

Comments
 (0)