@@ -4,7 +4,7 @@ using ArrayInterface, NonlinearSolve, SciMLBase
4
4
import ConcreteStructs: @concrete
5
5
import FastClosures: @closure
6
6
import FastLevenbergMarquardt as FastLM
7
- import StaticArraysCore: StaticArray
7
+ import StaticArraysCore: SArray
8
8
9
9
@inline function _fast_lm_solver (:: FastLevenbergMarquardtJL{linsolve} , x) where {linsolve}
10
10
if linsolve === :cholesky
@@ -15,53 +15,54 @@ import StaticArraysCore: StaticArray
15
15
throw (ArgumentError (" Unknown FastLevenbergMarquardt Linear Solver: $linsolve " ))
16
16
end
17
17
end
18
+ @inline _fast_lm_solver (:: FastLevenbergMarquardtJL{linsolve} , :: SArray ) where {linsolve} = linsolve
18
19
19
- # TODO : Implement reinit
20
- @concrete struct FastLevenbergMarquardtJLCache
21
- f!
22
- J!
23
- prob
24
- alg
25
- lmworkspace
26
- solver
27
- kwargs
28
- end
29
-
30
- function SciMLBase. __init (prob:: NonlinearLeastSquaresProblem ,
20
+ function SciMLBase. __solve (prob:: NonlinearLeastSquaresProblem ,
31
21
alg:: FastLevenbergMarquardtJL , args... ; alias_u0 = false , abstol = nothing ,
32
22
reltol = nothing , maxiters = 1000 , termination_condition = nothing , kwargs... )
33
23
NonlinearSolve. __test_termination_condition (termination_condition,
34
24
:FastLevenbergMarquardt )
35
- if prob. u0 isa StaticArray # FIXME
36
- error (" FastLevenbergMarquardtJL does not support StaticArrays yet." )
37
- end
38
25
39
- _f!, u, resid = NonlinearSolve. __construct_extension_f (prob; alias_u0)
40
- f! = @closure (du, u, p) -> _f! (du, u)
26
+ fn, u, resid = NonlinearSolve. __construct_extension_f (prob; alias_u0,
27
+ can_handle_oop = Val (prob. u0 isa SArray))
28
+ f = if prob. u0 isa SArray
29
+ @closure (u, p) -> fn (u)
30
+ else
31
+ @closure (du, u, p) -> fn (du, u)
32
+ end
41
33
abstol = NonlinearSolve. DEFAULT_TOLERANCE (abstol, eltype (u))
42
34
reltol = NonlinearSolve. DEFAULT_TOLERANCE (reltol, eltype (u))
43
35
44
- _J! = NonlinearSolve. __construct_extension_jac (prob, alg, u, resid; alg. autodiff)
45
- J! = @closure (J, u, p) -> _J! (J, u)
46
- J = prob. f. jac_prototype === nothing ? similar (u, length (resid), length (u)) :
47
- zero (prob. f. jac_prototype)
36
+ _jac_fn = NonlinearSolve. __construct_extension_jac (prob, alg, u, resid; alg. autodiff,
37
+ can_handle_oop = Val (prob. u0 isa SArray))
38
+ jac_fn = if prob. u0 isa SArray
39
+ @closure (u, p) -> _jac_fn (u)
40
+ else
41
+ @closure (J, u, p) -> _jac_fn (J, u)
42
+ end
48
43
49
- solver = _fast_lm_solver (alg, u)
50
- LM = FastLM. LMWorkspace (u, resid, J)
44
+ solver_kwargs = (; xtol = reltol, ftol = reltol, gtol = abstol, maxit = maxiters,
45
+ alg. factor, alg. factoraccept, alg. factorreject, alg. minscale, alg. maxscale,
46
+ alg. factorupdate, alg. minfactor, alg. maxfactor)
51
47
52
- return FastLevenbergMarquardtJLCache (f!, J!, prob, alg, LM, solver,
53
- (; xtol = reltol, ftol = reltol, gtol = abstol, maxit = maxiters, alg. factor,
54
- alg. factoraccept, alg. factorreject, alg. minscale, alg. maxscale,
55
- alg. factorupdate, alg. minfactor, alg. maxfactor))
56
- end
48
+ if prob. u0 isa SArray
49
+ res, fx, info, iter, nfev, njev = FastLM. lmsolve (f, jac_fn, prob. u0;
50
+ solver_kwargs... )
51
+ LM, solver = nothing , nothing
52
+ else
53
+ J = prob. f. jac_prototype === nothing ? similar (u, length (resid), length (u)) :
54
+ zero (prob. f. jac_prototype)
55
+ solver = _fast_lm_solver (alg, u)
56
+ LM = FastLM. LMWorkspace (u, resid, J)
57
+
58
+ res, fx, info, iter, nfev, njev, LM, solver = FastLM. lmsolve! (f, jac_fn, LM;
59
+ solver, solver_kwargs... )
60
+ end
57
61
58
- function SciMLBase. solve! (cache:: FastLevenbergMarquardtJLCache )
59
- res, fx, info, iter, nfev, njev, LM, solver = FastLM. lmsolve! (cache. f!, cache. J!,
60
- cache. lmworkspace; cache. solver, cache. kwargs... )
61
62
stats = SciMLBase. NLStats (nfev, njev, - 1 , - 1 , iter)
62
63
retcode = info == - 1 ? ReturnCode. MaxIters : ReturnCode. Success
63
- return SciMLBase. build_solution (cache . prob, cache . alg, res, fx;
64
- retcode, original = (res, fx, info, iter, nfev, njev, LM, solver), stats)
64
+ return SciMLBase. build_solution (prob, alg, res, fx; retcode,
65
+ original = (res, fx, info, iter, nfev, njev, LM, solver), stats)
65
66
end
66
67
67
68
end
0 commit comments