Skip to content

Commit 88890d5

Browse files
Merge pull request #538 from SciML/nllstoopt
Add constructor to convert NLLS to OptimizationProblem
2 parents ffe68ae + 5d3ab7a commit 88890d5

File tree

2 files changed

+43
-0
lines changed

2 files changed

+43
-0
lines changed

src/problems/basic_problems.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -688,6 +688,21 @@ function OptimizationProblem(f, args...; kwargs...)
688688
OptimizationProblem{true}(OptimizationFunction{true}(f), args...; kwargs...)
689689
end
690690

691+
function OptimizationFunction(f::NonlinearFunction, adtype::AbstractADType = NoAD(); kwargs...)
692+
if isinplace(f)
693+
throw(ArgumentError("Converting NonlinearFunction to OptimizationFunction is not supported with in-place functions yet."))
694+
end
695+
OptimizationFunction((u, p) -> sum(abs2, f(u, p)), adtype; kwargs...)
696+
end
697+
698+
function OptimizationProblem(prob::NonlinearLeastSquaresProblem, adtype::AbstractADType = NoAD(); kwargs...)
699+
if isinplace(prob)
700+
throw(ArgumentError("Converting NonlinearLeastSquaresProblem to OptimizationProblem is not supported with in-place functions yet."))
701+
end
702+
optf = OptimizationFunction(prob.f, adtype; kwargs...)
703+
return OptimizationProblem(optf, prob.u0, prob.p; prob.kwargs..., kwargs...)
704+
end
705+
691706
isinplace(f::OptimizationFunction{iip}) where {iip} = iip
692707
isinplace(f::OptimizationProblem{iip}) where {iip} = iip
693708

test/downstream/nllsopt.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
using NonlinearSolve, Optimization, OptimizationNLopt, ForwardDiff
2+
3+
true_function(x, θ) = @. θ[1] * exp(θ[2] * x) * cos(θ[3] * x + θ[4])
4+
5+
θ_true = [1.0, 0.1, 2.0, 0.5]
6+
7+
x = [-1.0, -0.5, 0.0, 0.5, 1.0]
8+
9+
y_target = true_function(x, θ_true)
10+
11+
function loss_function(θ, p)
12+
= true_function(p, θ)
13+
return.- y_target
14+
end
15+
16+
θ_init = θ_true .+ randn!(similar(θ_true)) * 0.1
17+
prob_oop = NonlinearLeastSquaresProblem{false}(loss_function, θ_init, x)
18+
19+
solver = LevenbergMarquardt()
20+
21+
@time sol = solve(prob, solver; maxiters = 10000, abstol = 1e-8)
22+
23+
optf = OptimizationFunction(prob_oop.f, AutoForwardDiff())
24+
optprob = OptimizationProblem(optf, prob_oop.u0, prob_oop.p)
25+
@time sol = solve(optprob, NLopt.LD_LBFGS(); maxiters = 10000, abstol = 1e-8)
26+
27+
optprob = OptimizationProblem(prob_oop, AutoForwardDiff())
28+
@time sol = solve(optprob, NLopt.LD_LBFGS(); maxiters = 10000, abstol = 1e-8)

0 commit comments

Comments
 (0)