@@ -90,18 +90,33 @@ from NLopt for an example. The common local optimizer arguments are:
9090  - `local_reltol`: relative tolerance  in changes of the objective value 
9191  - `local_options`: `NamedTuple` of keyword arguments for local optimizer 
9292""" 
93- function  solve (prob:: SciMLBase.OptimizationProblem , alg, args... ;
94-         kwargs... ):: SciMLBase.AbstractOptimizationSolution 
95-     if  SciMLBase. has_init (alg)
96-         solve! (init (prob, alg, args... ; kwargs... ))
93+ function  solve (prob:: SciMLBase.OptimizationProblem , args... ; sensealg =  nothing ,
94+         u0 =  nothing , p =  nothing , wrap =  Val (true ), kwargs... ):: SciMLBase.AbstractOptimizationSolution 
95+     if  sensealg ===  nothing  &&  haskey (prob. kwargs, :sensealg )
96+         sensealg =  prob. kwargs[:sensealg ]
97+     end 
98+ 
99+     u0 =  u0 != =  nothing  ?  u0 :  prob. u0
100+     p =  p != =  nothing  ?  p :  prob. p
101+ 
102+     if  wrap isa  Val{true }
103+         wrap_sol (solve_up (prob,
104+             sensealg,
105+             u0,
106+             p,
107+             args... ;
108+             originator =  SciMLBase. ChainRulesOriginator (),
109+             kwargs... ))
97110    else 
98-         if  prob. u0 != =  nothing  &&  ! isconcretetype (eltype (prob. u0))
99-             throw (SciMLBase. NonConcreteEltypeError (eltype (prob. u0)))
100-         end 
101-         _check_opt_alg (prob, alg; kwargs... )
102-         __solve (prob, alg, args... ; kwargs... )
111+         solve_up (prob,
112+             sensealg,
113+             u0,
114+             p,
115+             args... ;
116+             originator =  SciMLBase. ChainRulesOrginator (),
117+             kwargs... )
103118    end 
104- end 
119+ end   
105120
106121function  solve (
107122        prob:: SciMLBase.EnsembleProblem{T} , args... ; kwargs... ) where  {T < :
@@ -216,3 +231,101 @@ end
216231function  __solve (prob:: SciMLBase.OptimizationProblem , alg, args... ; kwargs... )
217232    throw (OptimizerMissingError (alg))
218233end 
234+ 
235+ function  solve_up (prob:: SciMLBase.OptimizationProblem , sensealg, u0, p, args... ; originator =  SciMLBase. ChainRulesOriginator (),
236+     kwargs... )
237+     alg =  extract_opt_alg (args, kwargs, has_kwargs (prob) ?  prob. kwargs :  kwargs)
238+     _prob =  get_concrete_problem (prob; u0 =  u0, p =  p, kwargs... )
239+     if  length (args) <  1 
240+         solve_call (_prob, alg, Base. tails (args)... , kwargs... )
241+     else 
242+         solve_call (_prob, alg; kwargs... )
243+     end 
244+ end 
245+ 
246+ function  solve_call (_prob, alg, args... ; merge_callbacks =  true , kwargshandle =  nothing ,
247+     kwargs... )
248+     kwargshandle =  kwargshandle ===  nothing  ?  KeywordArgError :  kwargshandle
249+     kwargshandle =  has_kwargs (_prob) &&  haskey (_prob. kwargs, :kwargshandle ) ? 
250+                    _prob. kwargs[:kwargshandle ] :  kwargshandle
251+ 
252+     if  has_kwargs (_prob)
253+         kwargs =  isempty (_prob. kwargs) ?  kwargs :  merge (values (_prob. kwargs), kwargs)
254+     end 
255+ 
256+     checkkwargs (kwargshandle; kwargs... )
257+ 
258+     if  SciMLBase. has_init (alg)
259+         solve! (init (_prob, alg, args... ; kwargs... ))
260+     else 
261+         if  _prob. u0 != =  nothing  &&  ! isconcretetype (eltype (_prob. u0))
262+             throw (SciMLBase. NonConcreteEltypeError (eltype (_prob. u0)))
263+         end 
264+         _check_opt_alg (prob, alg; kwargs... )
265+         __solve (_prob, alg, args... ; kwargs... )
266+     end 
267+ end 
268+ 
269+ function  get_concrete_problem (prob:: OptimizationProblem ; kwargs... )
270+     oldprob =  prob
271+     prob =  get_updated_symbolic_problem (get_root_indp (prob), prob; kwargs... )
272+     if  prob != =  oldprob
273+         kwargs =  (;kwargs... , u0 =  SII. state_values (prob), p =  SII. parameter_values (prob))
274+     end 
275+     p =  get_concrete_p (prob, kwargs)
276+     u0 =  get_concrete_u0 (prob, false , nothing , kwargs)
277+     u0 =  promote_u0 (u0, p, nothing )
278+     remake (prob; u0 =  u0, p =  p)
279+ 
280+ end 
281+ 
282+ 
283+ @inline  function  extract_opt_alg (solve_args, solve_kwargs, prob_kwargs)
284+     if  isempty (solve_args) ||  isnothing (first (solve_args))
285+         if  haskey (solve_kwargs, :alg )
286+             solve_kwargs[:alg ]
287+         elseif  haskey (prob_kwargs, :alg )
288+             prob_kwargs[:alg ]
289+         else 
290+             nothing 
291+         end 
292+     else 
293+         first (solve_args)
294+     end 
295+ end 
296+ 
297+ 
298+ function  _solve_forward (prob, sensealg, u0, p, originator, args... ; merge_callbacks =  true ,
299+     kwargs... )
300+     alg =  extract_opt_alg (args, kwargs, prob. kwargs)
301+     _prob =  get_concrete_problem (prob; u0 =  u0, p =  p, kwargs... )
302+ 
303+     if  has_kwargs (_prob)
304+         kwargs =  isempty (_porb. kwargs) ?  kwargs :  merge (values (_prob. kwargs), kwargs)
305+     end 
306+ 
307+     if  length (args) >  1 
308+         _concrete_solve_forward (_prob, alg, sensealg, u0, p, originator,
309+             Base. tail (args)... ; kwargs... )
310+     else 
311+         _concrete_solve_forward (_prob, alg, sensealg, u0, p, originator; kwargs... )
312+     end 
313+ end 
314+ 
315+ function  _solve_adjoint (_prob, sensealg, u0, p, originator, args... ; merge_callbacks =  true , 
316+     kwargs... )
317+     alg =  extract_alg (args, kwargs, prob. kwargs)
318+     
319+     _prob =  get_concrete_problem (prob; u0 =  u0, p =  p, kwargs... )
320+ 
321+     if  has_kwargs (_prob)
322+         kwargs =  isempty (_prob. kwargs) ?  kwargs :  merge (values (_prob. kwargs), kwargs)
323+     end 
324+     
325+     if  length (args) >  1 
326+         _concrete_solve_adjoint (_prob, alg, sensealg, u0, p, originator,
327+             Base. tail (args)... ; kwargs... )
328+     else 
329+         _concrete_solve_adjoint (_prob, alg, sensealg, u0, p, originator; kwargs... )
330+     end 
331+ end  
0 commit comments