Skip to content

Commit 169d5c2

Browse files
Add OptimizationFunction conversion and adtype arg and tests
1 parent c838cf2 commit 169d5c2

File tree

2 files changed

+38
-2
lines changed

2 files changed

+38
-2
lines changed

src/problems/basic_problems.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -688,11 +688,18 @@ function OptimizationProblem(f, args...; kwargs...)
688688
OptimizationProblem{true}(OptimizationFunction{true}(f), args...; kwargs...)
689689
end
690690

691-
function OptimizationProblem(prob::NonlinearLeastSquaresProblem; kwargs...)
691+
function OptimizationFunction(f::NonlinearFunction, adtype::AbstractADType = NoAD(); kwargs...)
692+
if isinplace(f)
693+
throw(ArgumentError("Converting NonlinearFunction to OptimizationFunction is not supported with in-place functions yet."))
694+
end
695+
OptimizationFunction((u, p) -> sum(abs2, f(u, p)), adtype; kwargs...)
696+
end
697+
698+
function OptimizationProblem(prob::NonlinearLeastSquaresProblem, adtype::AbstractADType = NoAD(); kwargs...)
692699
if isinplace(prob)
693700
throw(ArgumentError("Converting NonlinearLeastSquaresProblem to OptimizationProblem is not supported with in-place functions yet."))
694701
end
695-
optf = OptimizationFunction(sum prob.f, grad = (Jv, u, p) -> prob.f.jvp(Jv, prob.f(u, p), u, p), kwargs...)
702+
optf = OptimizationFunction(prob.f, adtype; kwargs...)
696703
return OptimizationProblem(optf, prob.u0, prob.p; prob.kwargs..., kwargs...)
697704
end
698705

test/downstream/nllsopt.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
using NonlinearSolve, Optimization, OptimizationNLopt, ForwardDiff
2+
import FastLevenbergMarquardt, LeastSquaresOptim
3+
4+
true_function(x, θ) = @. θ[1] * exp(θ[2] * x) * cos(θ[3] * x + θ[4])
5+
6+
θ_true = [1.0, 0.1, 2.0, 0.5]
7+
8+
x = [-1.0, -0.5, 0.0, 0.5, 1.0]
9+
10+
y_target = true_function(x, θ_true)
11+
12+
function loss_function(θ, p)
13+
= true_function(p, θ)
14+
return.- y_target
15+
end
16+
17+
θ_init = θ_true .+ randn!(similar(θ_true)) * 0.1
18+
prob_oop = NonlinearLeastSquaresProblem{false}(loss_function, θ_init, x)
19+
20+
solver = LevenbergMarquardt()
21+
22+
@time sol = solve(prob, solver; maxiters = 10000, abstol = 1e-8)
23+
24+
optf = OptimizationFunction(prob_oop.f, AutoForwardDiff())
25+
optprob = OptimizationProblem(optf, prob_oop.u0, prob_oop.p)
26+
@time sol = solve(optprob, NLopt.LD_LBFGS(); maxiters = 10000, abstol = 1e-8)
27+
28+
optprob = OptimizationProblem(prob_oop, AutoForwardDiff())
29+
@time sol = solve(optprob, NLopt.LD_LBFGS(); maxiters = 10000, abstol = 1e-8)

0 commit comments

Comments
 (0)