Skip to content

Commit 1663669

Browse files
committed
Special case for static arrays in FastLM
1 parent acb4737 commit 1663669

File tree

2 files changed

+56
-38
lines changed

2 files changed

+56
-38
lines changed

ext/NonlinearSolveFastLevenbergMarquardtExt.jl

Lines changed: 35 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using ArrayInterface, NonlinearSolve, SciMLBase
44
import ConcreteStructs: @concrete
55
import FastClosures: @closure
66
import FastLevenbergMarquardt as FastLM
7-
import StaticArraysCore: StaticArray
7+
import StaticArraysCore: SArray
88

99
@inline function _fast_lm_solver(::FastLevenbergMarquardtJL{linsolve}, x) where {linsolve}
1010
if linsolve === :cholesky
@@ -15,53 +15,54 @@ import StaticArraysCore: StaticArray
1515
throw(ArgumentError("Unknown FastLevenbergMarquardt Linear Solver: $linsolve"))
1616
end
1717
end
18+
@inline _fast_lm_solver(::FastLevenbergMarquardtJL{linsolve}, ::SArray) where {linsolve} = linsolve
1819

19-
# TODO: Implement reinit
20-
@concrete struct FastLevenbergMarquardtJLCache
21-
f!
22-
J!
23-
prob
24-
alg
25-
lmworkspace
26-
solver
27-
kwargs
28-
end
29-
30-
function SciMLBase.__init(prob::NonlinearLeastSquaresProblem,
20+
function SciMLBase.__solve(prob::NonlinearLeastSquaresProblem,
3121
alg::FastLevenbergMarquardtJL, args...; alias_u0 = false, abstol = nothing,
3222
reltol = nothing, maxiters = 1000, termination_condition = nothing, kwargs...)
3323
NonlinearSolve.__test_termination_condition(termination_condition,
3424
:FastLevenbergMarquardt)
35-
if prob.u0 isa StaticArray # FIXME
36-
error("FastLevenbergMarquardtJL does not support StaticArrays yet.")
37-
end
3825

39-
_f!, u, resid = NonlinearSolve.__construct_extension_f(prob; alias_u0)
40-
f! = @closure (du, u, p) -> _f!(du, u)
26+
fn, u, resid = NonlinearSolve.__construct_extension_f(prob; alias_u0,
27+
can_handle_oop = Val(prob.u0 isa SArray))
28+
f = if prob.u0 isa SArray
29+
@closure (u, p) -> fn(u)
30+
else
31+
@closure (du, u, p) -> fn(du, u)
32+
end
4133
abstol = NonlinearSolve.DEFAULT_TOLERANCE(abstol, eltype(u))
4234
reltol = NonlinearSolve.DEFAULT_TOLERANCE(reltol, eltype(u))
4335

44-
_J! = NonlinearSolve.__construct_extension_jac(prob, alg, u, resid; alg.autodiff)
45-
J! = @closure (J, u, p) -> _J!(J, u)
46-
J = prob.f.jac_prototype === nothing ? similar(u, length(resid), length(u)) :
47-
zero(prob.f.jac_prototype)
36+
_jac_fn = NonlinearSolve.__construct_extension_jac(prob, alg, u, resid; alg.autodiff,
37+
can_handle_oop = Val(prob.u0 isa SArray))
38+
jac_fn = if prob.u0 isa SArray
39+
@closure (u, p) -> _jac_fn(u)
40+
else
41+
@closure (J, u, p) -> _jac_fn(J, u)
42+
end
4843

49-
solver = _fast_lm_solver(alg, u)
50-
LM = FastLM.LMWorkspace(u, resid, J)
44+
solver_kwargs = (; xtol = reltol, ftol = reltol, gtol = abstol, maxit = maxiters,
45+
alg.factor, alg.factoraccept, alg.factorreject, alg.minscale, alg.maxscale,
46+
alg.factorupdate, alg.minfactor, alg.maxfactor)
5147

52-
return FastLevenbergMarquardtJLCache(f!, J!, prob, alg, LM, solver,
53-
(; xtol = reltol, ftol = reltol, gtol = abstol, maxit = maxiters, alg.factor,
54-
alg.factoraccept, alg.factorreject, alg.minscale, alg.maxscale,
55-
alg.factorupdate, alg.minfactor, alg.maxfactor))
56-
end
48+
if prob.u0 isa SArray
49+
res, fx, info, iter, nfev, njev = FastLM.lmsolve(f, jac_fn, prob.u0;
50+
solver_kwargs...)
51+
LM, solver = nothing, nothing
52+
else
53+
J = prob.f.jac_prototype === nothing ? similar(u, length(resid), length(u)) :
54+
zero(prob.f.jac_prototype)
55+
solver = _fast_lm_solver(alg, u)
56+
LM = FastLM.LMWorkspace(u, resid, J)
57+
58+
res, fx, info, iter, nfev, njev, LM, solver = FastLM.lmsolve!(f, jac_fn, LM;
59+
solver, solver_kwargs...)
60+
end
5761

58-
function SciMLBase.solve!(cache::FastLevenbergMarquardtJLCache)
59-
res, fx, info, iter, nfev, njev, LM, solver = FastLM.lmsolve!(cache.f!, cache.J!,
60-
cache.lmworkspace; cache.solver, cache.kwargs...)
6162
stats = SciMLBase.NLStats(nfev, njev, -1, -1, iter)
6263
retcode = info == -1 ? ReturnCode.MaxIters : ReturnCode.Success
63-
return SciMLBase.build_solution(cache.prob, cache.alg, res, fx;
64-
retcode, original = (res, fx, info, iter, nfev, njev, LM, solver), stats)
64+
return SciMLBase.build_solution(prob, alg, res, fx; retcode,
65+
original = (res, fx, info, iter, nfev, njev, LM, solver), stats)
6566
end
6667

6768
end

test/wrappers/nlls.jl

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using NonlinearSolve, LinearAlgebra, Test, StableRNGs, Random, ForwardDiff, Zygote
1+
using NonlinearSolve,
2+
LinearAlgebra, Test, StableRNGs, StaticArrays, Random, ForwardDiff, Zygote
23
import FastLevenbergMarquardt, LeastSquaresOptim, MINPACK
34

45
true_function(x, θ) = @. θ[1] * exp(θ[2] * x) * cos(θ[3] * x + θ[4])
@@ -8,7 +9,7 @@ true_function(y, x, θ) = (@. y = θ[1] * exp(θ[2] * x) * cos(θ[3] * x + θ[4]
89

910
x = [-1.0, -0.5, 0.0, 0.5, 1.0]
1011

11-
y_target = true_function(x, θ_true)
12+
const y_target = true_function(x, θ_true)
1213

1314
function loss_function(θ, p)
1415
= true_function(p, θ)
@@ -34,7 +35,7 @@ autodiff in (nothing, AutoForwardDiff(), AutoFiniteDiff(), :central, :forward)]
3435
for prob in nlls_problems, solver in solvers
3536
@time sol = solve(prob, solver; maxiters = 10000, abstol = 1e-8)
3637
@test SciMLBase.successful_retcode(sol)
37-
@test norm(sol.resid) < 1e-6
38+
@test norm(sol.resid, Inf) < 1e-6
3839
end
3940

4041
function jac!(J, θ, p)
@@ -76,5 +77,21 @@ append!(solvers, [CMINPACK(; method) for method in (:auto, :lm, :lmdif)])
7677

7778
for solver in solvers, prob in probs
7879
@time sol = solve(prob, solver; maxiters = 10000, abstol = 1e-8)
79-
@test norm(sol.resid) < 1e-6
80+
@test norm(sol.resid, Inf) < 1e-6
8081
end
82+
83+
# Static Arrays -- Fast Levenberg-Marquardt
84+
x_sa = SA[-1.0, -0.5, 0.0, 0.5, 1.0]
85+
86+
const y_target_sa = true_function(x_sa, θ_true)
87+
88+
function loss_function_sa(θ, p)
89+
= true_function(p, θ)
90+
return.- y_target_sa
91+
end
92+
93+
θ_init_sa = SVector{4}(θ_init)
94+
prob_sa = NonlinearLeastSquaresProblem{false}(loss_function_sa, θ_init_sa, x)
95+
96+
@time sol = solve(prob_sa, FastLevenbergMarquardtJL())
97+
@test norm(sol.resid, Inf) < 1e-6

0 commit comments

Comments
 (0)