Skip to content

Commit 3515b98

Browse files
committed
use solve_up for SciMLSensitivity integration
1 parent e18cfe2 commit 3515b98

File tree

2 files changed

+127
-11
lines changed

2 files changed

+127
-11
lines changed

lib/OptimizationBase/src/OptimizationBase.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@ import SciMLBase: solve, init, solve!, __init, __solve,
1313
allowsconstraints, requiresconstraints,
1414
allowscallback, requiresgradient,
1515
requireshessian, requiresconsjac,
16-
requiresconshess
16+
requiresconshess, wrap_sol, has_kwargs,
17+
get_root_indp, get_updated_symbolic_problem,
18+
get_concrete_p, get_concrete_u0, promote_u0,
19+
promote_p, KeywordArgError, checkkwargs
1720

1821
export ObjSense, MaxSense, MinSense
1922
export allowsbounds, requiresbounds, allowsconstraints, requiresconstraints,

lib/OptimizationBase/src/solve.jl

Lines changed: 123 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -90,18 +90,33 @@ from NLopt for an example. The common local optimizer arguments are:
9090
- `local_reltol`: relative tolerance in changes of the objective value
9191
- `local_options`: `NamedTuple` of keyword arguments for local optimizer
9292
"""
93-
function solve(prob::SciMLBase.OptimizationProblem, alg, args...;
94-
kwargs...)::SciMLBase.AbstractOptimizationSolution
95-
if SciMLBase.has_init(alg)
96-
solve!(init(prob, alg, args...; kwargs...))
93+
function solve(prob::SciMLBase.OptimizationProblem, args...; sensealg = nothing,
94+
u0 = nothing, p = nothing, wrap = Val(true), kwargs...)::SciMLBase.AbstractOptimizationSolution
95+
if sensealg === nothing && haskey(prob.kwargs, :sensealg)
96+
sensealg = prob.kwargs[:sensealg]
97+
end
98+
99+
u0 = u0 !== nothing ? u0 : prob.u0
100+
p = p !== nothing ? p : prob.p
101+
102+
if wrap isa Val{true}
103+
wrap_sol(solve_up(prob,
104+
sensealg,
105+
u0,
106+
p,
107+
args...;
108+
originator = SciMLBase.ChainRulesOriginator(),
109+
kwargs...))
97110
else
98-
if prob.u0 !== nothing && !isconcretetype(eltype(prob.u0))
99-
throw(SciMLBase.NonConcreteEltypeError(eltype(prob.u0)))
100-
end
101-
_check_opt_alg(prob, alg; kwargs...)
102-
__solve(prob, alg, args...; kwargs...)
111+
solve_up(prob,
112+
sensealg,
113+
u0,
114+
p,
115+
args...;
116+
originator = SciMLBase.ChainRulesOrginator(),
117+
kwargs...)
103118
end
104-
end
119+
end
105120

106121
function solve(
107122
prob::SciMLBase.EnsembleProblem{T}, args...; kwargs...) where {T <:
@@ -216,3 +231,101 @@ end
216231
function __solve(prob::SciMLBase.OptimizationProblem, alg, args...; kwargs...)
217232
throw(OptimizerMissingError(alg))
218233
end
234+
235+
function solve_up(prob::SciMLBase.OptimizationProblem, sensealg, u0, p, args...; originator = SciMLBase.ChainRulesOriginator(),
236+
kwargs...)
237+
alg = extract_opt_alg(args, kwargs, has_kwargs(prob) ? prob.kwargs : kwargs)
238+
_prob = get_concrete_problem(prob; u0 = u0, p = p, kwargs...)
239+
if length(args) < 1
240+
solve_call(_prob, alg, Base.tails(args)..., kwargs...)
241+
else
242+
solve_call(_prob, alg; kwargs...)
243+
end
244+
end
245+
246+
function solve_call(_prob, alg, args...; merge_callbacks = true, kwargshandle = nothing,
247+
kwargs...)
248+
kwargshandle = kwargshandle === nothing ? KeywordArgError : kwargshandle
249+
kwargshandle = has_kwargs(_prob) && haskey(_prob.kwargs, :kwargshandle) ?
250+
_prob.kwargs[:kwargshandle] : kwargshandle
251+
252+
if has_kwargs(_prob)
253+
kwargs = isempty(_prob.kwargs) ? kwargs : merge(values(_prob.kwargs), kwargs)
254+
end
255+
256+
checkkwargs(kwargshandle; kwargs...)
257+
258+
if SciMLBase.has_init(alg)
259+
solve!(init(_prob, alg, args...; kwargs...))
260+
else
261+
if _prob.u0 !== nothing && !isconcretetype(eltype(_prob.u0))
262+
throw(SciMLBase.NonConcreteEltypeError(eltype(_prob.u0)))
263+
end
264+
_check_opt_alg(prob, alg; kwargs...)
265+
__solve(_prob, alg, args...; kwargs...)
266+
end
267+
end
268+
269+
function get_concrete_problem(prob::OptimizationProblem; kwargs...)
270+
oldprob = prob
271+
prob = get_updated_symbolic_problem(get_root_indp(prob), prob; kwargs...)
272+
if prob !== oldprob
273+
kwargs = (;kwargs..., u0 = SII.state_values(prob), p = SII.parameter_values(prob))
274+
end
275+
p = get_concrete_p(prob, kwargs)
276+
u0 = get_concrete_u0(prob, false, nothing, kwargs)
277+
u0 = promote_u0(u0, p, nothing)
278+
remake(prob; u0 = u0, p = p)
279+
280+
end
281+
282+
283+
@inline function extract_opt_alg(solve_args, solve_kwargs, prob_kwargs)
284+
if isempty(solve_args) || isnothing(first(solve_args))
285+
if haskey(solve_kwargs, :alg)
286+
solve_kwargs[:alg]
287+
elseif haskey(prob_kwargs, :alg)
288+
prob_kwargs[:alg]
289+
else
290+
nothing
291+
end
292+
else
293+
first(solve_args)
294+
end
295+
end
296+
297+
298+
function _solve_forward(prob, sensealg, u0, p, originator, args...; merge_callbacks = true,
299+
kwargs...)
300+
alg = extract_opt_alg(args, kwargs, prob.kwargs)
301+
_prob = get_concrete_problem(prob; u0 = u0, p = p, kwargs...)
302+
303+
if has_kwargs(_prob)
304+
kwargs = isempty(_porb.kwargs) ? kwargs : merge(values(_prob.kwargs), kwargs)
305+
end
306+
307+
if length(args) > 1
308+
_concrete_solve_forward(_prob, alg, sensealg, u0, p, originator,
309+
Base.tail(args)...; kwargs...)
310+
else
311+
_concrete_solve_forward(_prob, alg, sensealg, u0, p, originator; kwargs...)
312+
end
313+
end
314+
315+
function _solve_adjoint(_prob, sensealg, u0, p, originator, args...; merge_callbacks = true,
316+
kwargs...)
317+
alg = extract_alg(args, kwargs, prob.kwargs)
318+
319+
_prob = get_concrete_problem(prob; u0 = u0, p = p, kwargs...)
320+
321+
if has_kwargs(_prob)
322+
kwargs = isempty(_prob.kwargs) ? kwargs : merge(values(_prob.kwargs), kwargs)
323+
end
324+
325+
if length(args) > 1
326+
_concrete_solve_adjoint(_prob, alg, sensealg, u0, p, originator,
327+
Base.tail(args)...; kwargs...)
328+
else
329+
_concrete_solve_adjoint(_prob, alg, sensealg, u0, p, originator; kwargs...)
330+
end
331+
end

0 commit comments

Comments
 (0)