Skip to content

Commit dea66d1

Browse files
Run optimization tests and add constructors for it
1 parent 7e7367b commit dea66d1

File tree

5 files changed

+35
-5
lines changed

5 files changed

+35
-5
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
2121
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
2222
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
2323
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
24+
QuasiMonteCarlo = "8a4e6c94-4038-4cdc-81c3-7e6ffdb2a71b"
2425
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
2526
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
2627
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"

src/SciMLBase.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import ADTypes: AbstractADType
2525
import ChainRulesCore
2626
import ZygoteRules: @adjoint
2727
import FillArrays
28-
28+
import QuasiMonteCarlo
2929
using Reexport
3030
using SciMLOperators
3131
using SciMLOperators:

src/ensemble/ensemble_problems.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,23 @@ function EnsembleProblem(; prob,
4444
EnsembleProblem(prob; prob_func, output_func, reduction, u_init, safetycopy)
4545
end
4646

47+
#since NonlinearProblem might want to use this dispatch as well
48+
function SciMLBase.EnsembleProblem(prob::AbstractSciMLProblem, u0s::Vector{Vector{T}}; kwargs...) where {T}
49+
prob_func = (prob, i, repeat = nothing) -> remake(prob, u0 = u0s[i])
50+
return SciMLBase.EnsembleProblem(prob; prob_func, kwargs...)
51+
end
52+
53+
#only makes sense for OptimizationProblem, might make sense for IntervalNonlinearProblem
54+
function SciMLBase.EnsembleProblem(prob::OptimizationProblem, trajectories::Int; kwargs...)
55+
if prob.lb !== nothing && prob.ub !== nothing
56+
u0s = QuasiMonteCarlo.sample(trajectories, prob.lb, prob.ub, QuasiMonteCarlo.LatinHypercubeSample())
57+
prob_func = (prob, i, repeat = nothing) -> remake(prob, u0 = u0s[:, i])
58+
else
59+
error("EnsembleProblem with `trajectories` as second argument requires lower and upper bounds to be defined in the `OptimizationProblem`.")
60+
end
61+
return SciMLBase.EnsembleProblem(prob; prob_func, kwargs...)
62+
end
63+
4764
struct WeightedEnsembleProblem{T1 <: AbstractEnsembleProblem, T2 <: AbstractVector} <:
4865
AbstractEnsembleProblem
4966
ensembleprob::T1

test/downstream/ensemble_nondes.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,16 @@ sol = Optimization.solve(ensembleprob, OptimizationOptimJL.BFGS(), EnsembleThrea
2323
@test findmin(i -> sol[i].objective, 1:4)[1] < sol1.objective
2424

2525
sol = Optimization.solve(ensembleprob, OptimizationOptimJL.BFGS(), EnsembleDistributed(), trajectories = 5, maxiters = 5)
26-
@test findmin(i -> sol[i].objective, 1:4)[1] < sol1.objective
26+
@test findmin(i -> sol[i].objective, 1:4)[1] < sol1.objective
27+
28+
using NonlinearSolve
29+
30+
f(u, p) = u .* u .- p
31+
u0 = [1.0, 1.0]
32+
p = 2.0
33+
prob = NonlinearProblem(f, u0, p)
34+
ensembleprob = EnsembleProblem(prob, [u0, u0 .+ rand(2), u0 .+ rand(2), u0 .+ rand(2)])
35+
36+
sol = solve(ensembleprob, EnsembleThreads(), trajectories = 4, maxiters = 100)
37+
38+
sol = solve(ensembleprob, EnsembleDistributed(), trajectories = 4, maxiters = 100)

test/runtests.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,9 @@ end
7878
@time @safetestset "Ensemble solution statistics" begin
7979
include("downstream/ensemble_stats.jl")
8080
end
81-
# @time @safetestset "Ensemble Optimization and Nonlinear problems" begin
82-
# include("downstream/ensemble_nondes.jl")
83-
# end
81+
@time @safetestset "Ensemble Optimization and Nonlinear problems" begin
82+
include("downstream/ensemble_nondes.jl")
83+
end
8484
@time @safetestset "Ensemble with DifferentialEquations automatic algorithm selection" begin
8585
include("downstream/ensemble_diffeq.jl")
8686
end

0 commit comments

Comments
 (0)