Skip to content

Commit 06d5c2c

Browse files
Merge pull request #534 from SciML/optensemble
Relax type of alg in ensemble solve for optimization
2 parents 88890d5 + ebcbae3 commit 06d5c2c

File tree

9 files changed

+82
-3
lines changed

9 files changed

+82
-3
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
2020
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
2121
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
2222
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
23+
QuasiMonteCarlo = "8a4e6c94-4038-4cdc-81c3-7e6ffdb2a71b"
2324
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
2425
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
2526
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
@@ -78,6 +79,7 @@ Statistics = "1"
7879
SymbolicIndexingInterface = "0.2"
7980
Tables = "1"
8081
TruncatedStacktraces = "1"
82+
QuasiMonteCarlo = "0.3"
8183
Zygote = "0.6"
8284
julia = "1.9"
8385

src/SciMLBase.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ import EnumX
2323
import TruncatedStacktraces
2424
import ADTypes: AbstractADType
2525
import FillArrays
26-
26+
import QuasiMonteCarlo
2727
using Reexport
2828
using SciMLOperators
2929
using SciMLOperators:

src/ensemble/basic_ensemble_solve.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,10 @@ function __solve(prob::EnsembleProblem{<:AbstractVector{<:AbstractSciMLProblem}}
5757
end
5858

5959
function __solve(prob::AbstractEnsembleProblem,
60-
alg::Union{AbstractDEAlgorithm, Nothing},
60+
alg::A,
6161
ensemblealg::BasicEnsembleAlgorithm;
6262
trajectories, batch_size = trajectories,
63-
pmap_batch_size = batch_size ÷ 100 > 0 ? batch_size ÷ 100 : 1, kwargs...)
63+
pmap_batch_size = batch_size ÷ 100 > 0 ? batch_size ÷ 100 : 1, kwargs...) where {A}
6464
num_batches = trajectories ÷ batch_size
6565
num_batches < 1 &&
6666
error("trajectories ÷ batch_size cannot be less than 1, got $num_batches")

src/ensemble/ensemble_problems.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,23 @@ function EnsembleProblem(; prob,
4444
EnsembleProblem(prob; prob_func, output_func, reduction, u_init, safetycopy)
4545
end
4646

47+
#since NonlinearProblem might want to use this dispatch as well
48+
function SciMLBase.EnsembleProblem(prob::AbstractSciMLProblem, u0s::Vector{Vector{T}}; kwargs...) where {T}
49+
prob_func = (prob, i, repeat = nothing) -> remake(prob, u0 = u0s[i])
50+
return SciMLBase.EnsembleProblem(prob; prob_func, kwargs...)
51+
end
52+
53+
#only makes sense for OptimizationProblem, might make sense for IntervalNonlinearProblem
54+
function SciMLBase.EnsembleProblem(prob::OptimizationProblem, trajectories::Int; kwargs...)
55+
if prob.lb !== nothing && prob.ub !== nothing
56+
u0s = QuasiMonteCarlo.sample(trajectories, prob.lb, prob.ub, QuasiMonteCarlo.LatinHypercubeSample())
57+
prob_func = (prob, i, repeat = nothing) -> remake(prob, u0 = u0s[:, i])
58+
else
59+
error("EnsembleProblem with `trajectories` as second argument requires lower and upper bounds to be defined in the `OptimizationProblem`.")
60+
end
61+
return SciMLBase.EnsembleProblem(prob; prob_func, kwargs...)
62+
end
63+
4764
struct WeightedEnsembleProblem{T1 <: AbstractEnsembleProblem, T2 <: AbstractVector} <:
4865
AbstractEnsembleProblem
4966
ensembleprob::T1

src/solve.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,10 @@ function solve(prob::OptimizationProblem, alg, args...;
9898
end
9999
end
100100

101+
function SciMLBase.solve(prob::EnsembleProblem{T}, args...; kwargs...) where {T <: OptimizationProblem}
102+
return SciMLBase.__solve(prob, args...; kwargs...)
103+
end
104+
101105
function _check_opt_alg(prob::OptimizationProblem, alg; kwargs...)
102106
!allowsbounds(alg) && (!isnothing(prob.lb) || !isnothing(prob.ub)) &&
103107
throw(IncompatibleOptimizerError("The algorithm $(typeof(alg)) does not support box constraints. Either remove the `lb` or `ub` bounds passed to `OptimizationProblem` or use a different algorithm."))

test/downstream/Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
[deps]
22
BoundaryValueDiffEq = "764a87c0-6b3e-53db-9096-fe964310641d"
3+
DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
4+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
35
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
6+
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
47
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
58
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
69
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"

test/downstream/ensemble_diffeq.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
using DifferentialEquations
2+
3+
prob = ODEProblem((u, p, t) -> 1.01u, 0.5, (0.0, 1.0))
4+
function prob_func(prob, i, repeat)
5+
remake(prob, u0 = rand() * prob.u0)
6+
end
7+
ensemble_prob = EnsembleProblem(prob, prob_func = prob_func)
8+
sim = solve(ensemble_prob, Tsit5(), EnsembleThreads(), trajectories = 10)
9+
@test sim isa EnsembleSolution

test/downstream/ensemble_nondes.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
using Optimization, OptimizationOptimJL, ForwardDiff, Test
2+
3+
x0 = zeros(2)
4+
rosenbrock(x, p = nothing) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2
5+
l1 = rosenbrock(x0)
6+
7+
optf = OptimizationFunction(rosenbrock, Optimization.AutoForwardDiff())
8+
prob = OptimizationProblem(optf, x0)
9+
sol1 = Optimization.solve(prob, OptimizationOptimJL.BFGS(), maxiters = 5)
10+
11+
ensembleprob = Optimization.EnsembleProblem(prob, [x0, x0 .+ rand(2), x0 .+ rand(2), x0 .+ rand(2)])
12+
13+
sol = Optimization.solve(ensembleprob, OptimizationOptimJL.BFGS(), EnsembleThreads(), trajectories = 4, maxiters = 5)
14+
@test findmin(i -> sol[i].objective, 1:4)[1] < sol1.objective
15+
16+
sol = Optimization.solve(ensembleprob, OptimizationOptimJL.BFGS(), EnsembleDistributed(), trajectories = 4, maxiters = 5)
17+
@test findmin(i -> sol[i].objective, 1:4)[1] < sol1.objective
18+
19+
prob = OptimizationProblem(optf, x0, lb = [-0.5, -0.5], ub = [0.5, 0.5])
20+
ensembleprob = Optimization.EnsembleProblem(prob, 5, prob_func = (prob, i, repeat) -> remake(prob, u0 = rand(-0.5:0.001:0.5, 2)))
21+
22+
sol = Optimization.solve(ensembleprob, OptimizationOptimJL.BFGS(), EnsembleThreads(), trajectories = 5, maxiters = 5)
23+
@test findmin(i -> sol[i].objective, 1:4)[1] < sol1.objective
24+
25+
sol = Optimization.solve(ensembleprob, OptimizationOptimJL.BFGS(), EnsembleDistributed(), trajectories = 5, maxiters = 5)
26+
@test findmin(i -> sol[i].objective, 1:4)[1] < sol1.objective
27+
28+
using NonlinearSolve
29+
30+
f(u, p) = u .* u .- p
31+
u0 = [1.0, 1.0]
32+
p = 2.0
33+
prob = NonlinearProblem(f, u0, p)
34+
ensembleprob = EnsembleProblem(prob, [u0, u0 .+ rand(2), u0 .+ rand(2), u0 .+ rand(2)])
35+
36+
sol = solve(ensembleprob, EnsembleThreads(), trajectories = 4, maxiters = 100)
37+
38+
sol = solve(ensembleprob, EnsembleDistributed(), trajectories = 4, maxiters = 100)

test/runtests.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,12 @@ end
7878
@time @safetestset "Ensemble solution statistics" begin
7979
include("downstream/ensemble_stats.jl")
8080
end
81+
@time @safetestset "Ensemble Optimization and Nonlinear problems" begin
82+
include("downstream/ensemble_nondes.jl")
83+
end
84+
@time @safetestset "Ensemble with DifferentialEquations automatic algorithm selection" begin
85+
include("downstream/ensemble_diffeq.jl")
86+
end
8187
@time @safetestset "Symbol and integer based indexing of interpolated solutions" begin
8288
include("downstream/symbol_indexing.jl")
8389
end

0 commit comments

Comments
 (0)