Skip to content

Commit 4312353

Browse files
Merge pull request #341 from SciML/ap/fastlm_jac
Standardize the Extension Algorithms
2 parents fa16796 + bfaf014 commit 4312353

39 files changed

+736
-798
lines changed

.github/workflows/CI.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ jobs:
1818
fail-fast: false
1919
matrix:
2020
group:
21-
- Core
22-
- NLLS
21+
- RootFinding
22+
- NLLSSolvers
2323
- 23TestProblems
2424
- Wrappers
2525
- Miscellaneous

Project.toml

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NonlinearSolve"
22
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
33
authors = ["SciML"]
4-
version = "3.3.0"
4+
version = "3.4.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -10,6 +10,7 @@ ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
1010
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
1111
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
1212
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
13+
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
1314
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
1415
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1516
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
@@ -64,6 +65,7 @@ DiffEqBase = "6.144"
6465
EnumX = "1"
6566
Enzyme = "0.11.11"
6667
FastBroadcast = "0.2.8"
68+
FastClosures = "0.3"
6769
FastLevenbergMarquardt = "0.1"
6870
FiniteDiff = "2.21"
6971
FixedPointAcceleration = "0.3"
@@ -85,16 +87,17 @@ Printf = "1.9"
8587
Random = "1.91"
8688
RecursiveArrayTools = "3.2"
8789
Reexport = "1.2"
90+
SIAMFANLEquations = "1.0.1"
8891
SafeTestsets = "0.1"
8992
SciMLBase = "2.11"
9093
SciMLOperators = "0.3.7"
91-
SIAMFANLEquations = "1.0.1"
9294
SimpleNonlinearSolve = "1.0.2"
9395
SparseArrays = "1.9"
9496
SparseDiffTools = "2.14"
9597
SpeedMapping = "0.3"
9698
StableRNGs = "1"
9799
StaticArrays = "1.7"
100+
Sundials = "4.23.1"
98101
Symbolics = "5.13"
99102
Test = "1"
100103
UnPack = "1.0"
@@ -120,15 +123,16 @@ NonlinearProblemLibrary = "b7050fa9-e91f-4b37-bcee-a89a063da141"
120123
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
121124
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
122125
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
123-
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
124126
SIAMFANLEquations = "084e46ad-d928-497d-ad5e-07fa361a48c4"
127+
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
125128
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
126129
SpeedMapping = "f1835b91-879b-4a3f-a438-e4baacf14412"
127130
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
128131
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
132+
Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
129133
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
130134
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
131135
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
132136

133137
[targets]
134-
test = ["Aqua", "Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim", "FastLevenbergMarquardt", "NaNMath", "BandedMatrices", "DiffEqBase", "StableRNGs", "MINPACK", "NLsolve", "OrdinaryDiffEq", "SpeedMapping", "FixedPointAcceleration", "SIAMFANLEquations"]
138+
test = ["Aqua", "Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim", "FastLevenbergMarquardt", "NaNMath", "BandedMatrices", "DiffEqBase", "StableRNGs", "MINPACK", "NLsolve", "OrdinaryDiffEq", "SpeedMapping", "FixedPointAcceleration", "SIAMFANLEquations", "Sundials"]

docs/make.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ makedocs(; sitename = "NonlinearSolve.jl",
1414
DiffEqBase, SciMLBase],
1515
clean = true, doctest = false, linkcheck = true,
1616
linkcheck_ignore = ["https://twitter.com/ChrisRackauckas/status/1544743542094020615"],
17-
warnonly = [:cross_references], checkdocs = :export,
17+
checkdocs = :export,
1818
format = Documenter.HTML(assets = ["assets/favicon.ico"],
1919
canonical = "https://docs.sciml.ai/NonlinearSolve/stable/"),
2020
pages)

docs/src/api/siamfanlequations.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# SIAMFANLEquations.jl
22

3-
This is an extension for importing solvers from [SIAMFANLEquations.jl](https://github.com/ctkelley/SIAMFANLEquations.jl) into the SciML
3+
This is an extension for importing solvers from
4+
[SIAMFANLEquations.jl](https://github.com/ctkelley/SIAMFANLEquations.jl) into the SciML
45
interface. Note that these solvers do not come by default, and thus one needs to install
56
the package before using these solvers:
67

docs/src/basics/solve.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ solve(prob::SciMLBase.NonlinearProblem, args...; kwargs...)
1414
## Iteration Controls
1515

1616
- `maxiters::Int`: The maximum number of iterations to perform. Defaults to `1000`.
17-
- `abstol::Number`: The absolute tolerance.
18-
- `reltol::Number`: The relative tolerance.
17+
- `abstol::Number`: The absolute tolerance. Defaults to `real(oneunit(T)) * (eps(real(one(T))))^(4 // 5)`.
18+
- `reltol::Number`: The relative tolerance. Defaults to `real(oneunit(T)) * (eps(real(one(T))))^(4 // 5)`.
1919
- `termination_condition`: Termination Condition from DiffEqBase. Defaults to
2020
`AbsSafeBestTerminationMode()` for `NonlinearSolve.jl` and `AbsTerminateMode()` for
2121
`SimpleNonlinearSolve.jl`.

docs/src/solvers/FixedPointSolvers.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,10 @@ We are only listing the methods that natively solve fixed point problems.
4040

4141
- `FixedPointAccelerationJL()`: accelerates the convergence of a mapping to a fixed point
4242
by the Anderson acceleration algorithm and a few other methods.
43+
44+
### NLsolve.jl
45+
46+
In our tests, we have found the anderson method implemented here to NOT be the most
47+
robust.
48+
49+
- `NLsolveJL(; method = :anderson)`: Anderson acceleration for fixed point problems.

docs/src/solvers/NonlinearSystemSolvers.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,3 +143,10 @@ Newton-Krylov form. However, KINSOL is known to be less stable than some other
143143
implementations, as it has no line search or globalizer (trust region).
144144

145145
- `KINSOL()`: The KINSOL method of the SUNDIALS C library
146+
147+
### SIAMFANLEquations.jl
148+
149+
SIAMFANLEquations.jl is a wrapper for the methods in the SIAMFANLEquations.jl library.
150+
151+
- `SIAMFANLEquationsJL()`: A wrapper for using the methods in
152+
[SIAMFANLEquations.jl](https://github.com/ctkelley/SIAMFANLEquations.jl)

ext/NonlinearSolveFastLevenbergMarquardtExt.jl

Lines changed: 19 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import ConcreteStructs: @concrete
55
import FastLevenbergMarquardt as FastLM
66
import FiniteDiff, ForwardDiff
77

8-
function _fast_lm_solver(::FastLevenbergMarquardtJL{linsolve}, x) where {linsolve}
8+
@inline function _fast_lm_solver(::FastLevenbergMarquardtJL{linsolve}, x) where {linsolve}
99
if linsolve === :cholesky
1010
return FastLM.CholeskySolver(ArrayInterface.undefmatrix(x))
1111
elseif linsolve === :qr
@@ -15,6 +15,7 @@ function _fast_lm_solver(::FastLevenbergMarquardtJL{linsolve}, x) where {linsolv
1515
end
1616
end
1717

18+
# TODO: Implement reinit
1819
@concrete struct FastLevenbergMarquardtJLCache
1920
f!
2021
J!
@@ -25,68 +26,27 @@ end
2526
kwargs
2627
end
2728

28-
@concrete struct InplaceFunction{iip} <: Function
29-
f
30-
end
31-
32-
(f::InplaceFunction{true})(fx, x, p) = f.f(fx, x, p)
33-
(f::InplaceFunction{false})(fx, x, p) = (fx .= f.f(x, p))
34-
3529
function SciMLBase.__init(prob::NonlinearLeastSquaresProblem,
36-
alg::FastLevenbergMarquardtJL, args...; alias_u0 = false, abstol = 1e-8,
37-
reltol = 1e-8, maxiters = 1000, kwargs...)
30+
alg::FastLevenbergMarquardtJL, args...; alias_u0 = false, abstol = nothing,
31+
reltol = nothing, maxiters = 1000, kwargs...)
32+
# FIXME: Support scalar u0
33+
prob.u0 isa Number &&
34+
throw(ArgumentError("FastLevenbergMarquardtJL does not support scalar `u0`"))
3835
iip = SciMLBase.isinplace(prob)
3936
u = NonlinearSolve.__maybe_unaliased(prob.u0, alias_u0)
4037
fu = NonlinearSolve.evaluate_f(prob, u)
4138

42-
f! = InplaceFunction{iip}(prob.f)
39+
f! = NonlinearSolve.__make_inplace{iip}(prob.f, nothing)
40+
41+
abstol = NonlinearSolve.DEFAULT_TOLERANCE(abstol, eltype(u))
42+
reltol = NonlinearSolve.DEFAULT_TOLERANCE(reltol, eltype(u))
4343

4444
if prob.f.jac === nothing
45-
use_forward_diff = if alg.autodiff === nothing
46-
ForwardDiff.can_dual(eltype(u))
47-
else
48-
alg.autodiff isa AutoForwardDiff
49-
end
50-
uf = SciMLBase.JacobianWrapper{iip}(prob.f, prob.p)
51-
if use_forward_diff
52-
cache = iip ? ForwardDiff.JacobianConfig(uf, fu, u) :
53-
ForwardDiff.JacobianConfig(uf, u)
54-
else
55-
cache = FiniteDiff.JacobianCache(u, fu)
56-
end
57-
J! = if iip
58-
if use_forward_diff
59-
fu_cache = similar(fu)
60-
function (J, x, p)
61-
uf.p = p
62-
ForwardDiff.jacobian!(J, uf, fu_cache, x, cache)
63-
return J
64-
end
65-
else
66-
function (J, x, p)
67-
uf.p = p
68-
FiniteDiff.finite_difference_jacobian!(J, uf, x, cache)
69-
return J
70-
end
71-
end
72-
else
73-
if use_forward_diff
74-
function (J, x, p)
75-
uf.p = p
76-
ForwardDiff.jacobian!(J, uf, x, cache)
77-
return J
78-
end
79-
else
80-
function (J, x, p)
81-
uf.p = p
82-
J_ = FiniteDiff.finite_difference_jacobian(uf, x, cache)
83-
copyto!(J, J_)
84-
return J
85-
end
86-
end
87-
end
45+
alg = NonlinearSolve.get_concrete_algorithm(alg, prob)
46+
J! = NonlinearSolve.__construct_jac(prob, alg, u;
47+
can_handle_arbitrary_dims = Val(true))
8848
else
89-
J! = InplaceFunction{iip}(prob.f.jac)
49+
J! = NonlinearSolve.__make_inplace{iip}(prob.f.jac, nothing)
9050
end
9151

9252
J = similar(u, length(fu), length(u))
@@ -95,17 +55,16 @@ function SciMLBase.__init(prob::NonlinearLeastSquaresProblem,
9555
LM = FastLM.LMWorkspace(u, fu, J)
9656

9757
return FastLevenbergMarquardtJLCache(f!, J!, prob, alg, LM, solver,
98-
(; xtol = abstol, ftol = reltol, maxit = maxiters, alg.factor, alg.factoraccept,
99-
alg.factorreject, alg.minscale, alg.maxscale, alg.factorupdate, alg.minfactor,
100-
alg.maxfactor, kwargs...))
58+
(; xtol = reltol, ftol = reltol, gtol = abstol, maxit = maxiters, alg.factor,
59+
alg.factoraccept, alg.factorreject, alg.minscale, alg.maxscale,
60+
alg.factorupdate, alg.minfactor, alg.maxfactor))
10161
end
10262

10363
function SciMLBase.solve!(cache::FastLevenbergMarquardtJLCache)
10464
res, fx, info, iter, nfev, njev, LM, solver = FastLM.lmsolve!(cache.f!, cache.J!,
10565
cache.lmworkspace, cache.prob.p; cache.solver, cache.kwargs...)
10666
stats = SciMLBase.NLStats(nfev, njev, -1, -1, iter)
107-
retcode = info == 1 ? ReturnCode.Success :
108-
(info == -1 ? ReturnCode.MaxIters : ReturnCode.Default)
67+
retcode = info == -1 ? ReturnCode.MaxIters : ReturnCode.Success
10968
return SciMLBase.build_solution(cache.prob, cache.alg, res, fx;
11069
retcode, original = (res, fx, info, iter, nfev, njev, LM, solver), stats)
11170
end

ext/NonlinearSolveFixedPointAccelerationExt.jl

Lines changed: 9 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -7,43 +7,27 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::FixedPointAccelerationJL
77
show_trace::Val{PrintReports} = Val(false), termination_condition = nothing,
88
kwargs...) where {PrintReports}
99
@assert (termination_condition ===
10-
nothing)||(termination_condition isa AbsNormTerminationMode) "SpeedMappingJL does not support termination conditions!"
10+
nothing)||(termination_condition isa AbsNormTerminationMode) "FixedPointAccelerationJL does not support termination conditions!"
1111

12-
u0 = NonlinearSolve.__maybe_unaliased(prob.u0, alias_u0)
13-
u_size = size(u0)
14-
T = eltype(u0)
15-
iip = isinplace(prob)
16-
p = prob.p
12+
f, u0 = NonlinearSolve.__construct_f(prob; alias_u0, make_fixed_point = Val(true),
13+
force_oop = Val(true))
1714

18-
if !iip && prob.u0 isa Number
19-
# FixedPointAcceleration makes the scalar problem into a vector problem
20-
f = (u) -> [prob.f(u[1], p) .+ u[1]]
21-
elseif !iip && prob.u0 isa AbstractVector
22-
f = (u) -> (prob.f(u, p) .+ u)
23-
elseif !iip && prob.u0 isa AbstractArray
24-
f = (u) -> vec(prob.f(reshape(u, u_size), p) .+ u)
25-
elseif iip && prob.u0 isa AbstractVector
26-
du = similar(u0)
27-
f = (u) -> (prob.f(du, u, p); du .+ u)
28-
else
29-
du = similar(u0)
30-
f = (u) -> (prob.f(du, reshape(u, u_size), p); vec(du) .+ u)
31-
end
32-
33-
tol = abstol === nothing ? real(oneunit(T)) * (eps(real(one(T))))^(4 // 5) : abstol
15+
tol = NonlinearSolve.DEFAULT_TOLERANCE(abstol, eltype(u0))
3416

35-
sol = fixed_point(f, NonlinearSolve._vec(u0); Algorithm = alg.algorithm,
17+
sol = fixed_point(f, u0; Algorithm = alg.algorithm,
3618
ConvergenceMetricThreshold = tol, MaxIter = maxiters, MaxM = alg.m,
3719
ExtrapolationPeriod = alg.extrapolation_period, Dampening = alg.dampening,
3820
PrintReports, ReplaceInvalids = alg.replace_invalids,
3921
ConditionNumberThreshold = alg.condition_number_threshold, quiet_errors = true)
4022

41-
res = prob.u0 isa Number ? first(sol.FixedPoint_) : sol.FixedPoint_
42-
if res === missing
23+
if sol.FixedPoint_ === missing
24+
u0 = prob.u0 isa Number ? u0[1] : u0
4325
resid = NonlinearSolve.evaluate_f(prob, u0)
4426
res = u0
4527
converged = false
4628
else
29+
res = prob.u0 isa Number ? first(sol.FixedPoint_) :
30+
reshape(sol.FixedPoint_, size(prob.u0))
4731
resid = NonlinearSolve.evaluate_f(prob, res)
4832
converged = maximum(abs, resid) tol
4933
end

ext/NonlinearSolveLeastSquaresOptimExt.jl

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using NonlinearSolve, SciMLBase
44
import ConcreteStructs: @concrete
55
import LeastSquaresOptim as LSO
66

7-
function _lso_solver(::LeastSquaresOptimJL{alg, linsolve}) where {alg, linsolve}
7+
@inline function _lso_solver(::LeastSquaresOptimJL{alg, linsolve}) where {alg, linsolve}
88
ls = linsolve === :qr ? LSO.QR() :
99
(linsolve === :cholesky ? LSO.Cholesky() :
1010
(linsolve === :lsmr ? LSO.LSMR() : nothing))
@@ -17,41 +17,37 @@ function _lso_solver(::LeastSquaresOptimJL{alg, linsolve}) where {alg, linsolve}
1717
end
1818
end
1919

20+
# TODO: Implement reinit
2021
@concrete struct LeastSquaresOptimJLCache
2122
prob
2223
alg
2324
allocated_prob
2425
kwargs
2526
end
2627

27-
@concrete struct FunctionWrapper{iip}
28-
f
29-
p
30-
end
31-
32-
(f::FunctionWrapper{true})(du, u) = f.f(du, u, f.p)
33-
(f::FunctionWrapper{false})(du, u) = (du .= f.f(u, f.p))
34-
3528
function SciMLBase.__init(prob::NonlinearLeastSquaresProblem, alg::LeastSquaresOptimJL,
36-
args...; alias_u0 = false, abstol = 1e-8, reltol = 1e-8, verbose = false,
37-
maxiters = 1000, kwargs...)
29+
args...; alias_u0 = false, abstol = nothing, show_trace::Val{ShT} = Val(false),
30+
trace_level = TraceMinimal(), store_trace::Val{StT} = Val(false), maxiters = 1000,
31+
reltol = nothing, kwargs...) where {ShT, StT}
3832
iip = SciMLBase.isinplace(prob)
39-
u = alias_u0 ? prob.u0 : deepcopy(prob.u0)
33+
u = NonlinearSolve.__maybe_unaliased(prob.u0, alias_u0)
34+
35+
abstol = NonlinearSolve.DEFAULT_TOLERANCE(abstol, eltype(u))
36+
reltol = NonlinearSolve.DEFAULT_TOLERANCE(reltol, eltype(u))
4037

41-
f! = FunctionWrapper{iip}(prob.f, prob.p)
42-
g! = prob.f.jac === nothing ? nothing : FunctionWrapper{iip}(prob.f.jac, prob.p)
38+
f! = NonlinearSolve.__make_inplace{iip}(prob.f, prob.p)
39+
g! = NonlinearSolve.__make_inplace{iip}(prob.f.jac, prob.p)
4340

4441
resid_prototype = prob.f.resid_prototype === nothing ?
45-
(!iip ? prob.f(u, prob.p) : zeros(u)) :
46-
prob.f.resid_prototype
42+
(!iip ? prob.f(u, prob.p) : zeros(u)) : prob.f.resid_prototype
4743

4844
lsoprob = LSO.LeastSquaresProblem(; x = u, f!, y = resid_prototype, g!,
4945
J = prob.f.jac_prototype, alg.autodiff, output_length = length(resid_prototype))
5046
allocated_prob = LSO.LeastSquaresProblemAllocated(lsoprob, _lso_solver(alg))
5147

5248
return LeastSquaresOptimJLCache(prob, alg, allocated_prob,
53-
(; x_tol = abstol, f_tol = reltol, iterations = maxiters, show_trace = verbose,
54-
kwargs...))
49+
(; x_tol = reltol, f_tol = abstol, g_tol = abstol, iterations = maxiters,
50+
show_trace = ShT, store_trace = StT, show_every = trace_level.print_frequency))
5551
end
5652

5753
function SciMLBase.solve!(cache::LeastSquaresOptimJLCache)

0 commit comments

Comments
 (0)