Skip to content

Commit 1e27dd6

Browse files
committed
Standardize function construction and jacobians for FastLM
1 parent 2d19f54 commit 1e27dd6

16 files changed

+262
-228
lines changed

Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
1010
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
1111
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
1212
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
13+
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
1314
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
1415
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1516
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
@@ -64,6 +65,7 @@ DiffEqBase = "6.144"
6465
EnumX = "1"
6566
Enzyme = "0.11.11"
6667
FastBroadcast = "0.2.8"
68+
FastClosures = "0.3"
6769
FastLevenbergMarquardt = "0.1"
6870
FiniteDiff = "2.21"
6971
FixedPointAcceleration = "0.3"
@@ -85,10 +87,10 @@ Printf = "1.9"
8587
Random = "1.91"
8688
RecursiveArrayTools = "3.2"
8789
Reexport = "1.2"
90+
SIAMFANLEquations = "1.0.1"
8891
SafeTestsets = "0.1"
8992
SciMLBase = "2.11"
9093
SciMLOperators = "0.3.7"
91-
SIAMFANLEquations = "1.0.1"
9294
SimpleNonlinearSolve = "1.0.2"
9395
SparseArrays = "1.9"
9496
SparseDiffTools = "2.14"
@@ -120,8 +122,8 @@ NonlinearProblemLibrary = "b7050fa9-e91f-4b37-bcee-a89a063da141"
120122
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
121123
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
122124
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
123-
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
124125
SIAMFANLEquations = "084e46ad-d928-497d-ad5e-07fa361a48c4"
126+
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
125127
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
126128
SpeedMapping = "f1835b91-879b-4a3f-a438-e4baacf14412"
127129
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"

docs/make.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ makedocs(; sitename = "NonlinearSolve.jl",
1414
DiffEqBase, SciMLBase],
1515
clean = true, doctest = false, linkcheck = true,
1616
linkcheck_ignore = ["https://twitter.com/ChrisRackauckas/status/1544743542094020615"],
17-
warnonly = [:cross_references], checkdocs = :export,
17+
checkdocs = :export,
1818
format = Documenter.HTML(assets = ["assets/favicon.ico"],
1919
canonical = "https://docs.sciml.ai/NonlinearSolve/stable/"),
2020
pages)

ext/NonlinearSolveFastLevenbergMarquardtExt.jl

Lines changed: 9 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -25,71 +25,27 @@ end
2525
kwargs
2626
end
2727

28-
@concrete struct InplaceFunction{iip} <: Function
29-
f
30-
end
31-
32-
(f::InplaceFunction{true})(fx, x, p) = f.f(fx, x, p)
33-
(f::InplaceFunction{false})(fx, x, p) = (fx .= f.f(x, p))
34-
3528
function SciMLBase.__init(prob::NonlinearLeastSquaresProblem,
3629
alg::FastLevenbergMarquardtJL, args...; alias_u0 = false, abstol = nothing,
3730
reltol = nothing, maxiters = 1000, kwargs...)
31+
# FIXME: Support scalar u0
32+
prob.u0 isa Number &&
33+
throw(ArgumentError("FastLevenbergMarquardtJL does not support scalar `u0`"))
3834
iip = SciMLBase.isinplace(prob)
3935
u = NonlinearSolve.__maybe_unaliased(prob.u0, alias_u0)
4036
fu = NonlinearSolve.evaluate_f(prob, u)
4137

42-
f! = InplaceFunction{iip}(prob.f)
38+
f! = NonlinearSolve.__make_inplace{iip}(prob.f, nothing)
4339

4440
abstol = NonlinearSolve.DEFAULT_TOLERANCE(abstol, eltype(u))
4541
reltol = NonlinearSolve.DEFAULT_TOLERANCE(reltol, eltype(u))
4642

4743
if prob.f.jac === nothing
48-
use_forward_diff = if alg.autodiff === nothing
49-
ForwardDiff.can_dual(eltype(u))
50-
else
51-
alg.autodiff isa AutoForwardDiff
52-
end
53-
uf = SciMLBase.JacobianWrapper{iip}(prob.f, prob.p)
54-
if use_forward_diff
55-
cache = iip ? ForwardDiff.JacobianConfig(uf, fu, u) :
56-
ForwardDiff.JacobianConfig(uf, u)
57-
else
58-
cache = FiniteDiff.JacobianCache(u, fu)
59-
end
60-
J! = if iip
61-
if use_forward_diff
62-
fu_cache = similar(fu)
63-
function (J, x, p)
64-
uf.p = p
65-
ForwardDiff.jacobian!(J, uf, fu_cache, x, cache)
66-
return J
67-
end
68-
else
69-
function (J, x, p)
70-
uf.p = p
71-
FiniteDiff.finite_difference_jacobian!(J, uf, x, cache)
72-
return J
73-
end
74-
end
75-
else
76-
if use_forward_diff
77-
function (J, x, p)
78-
uf.p = p
79-
ForwardDiff.jacobian!(J, uf, x, cache)
80-
return J
81-
end
82-
else
83-
function (J, x, p)
84-
uf.p = p
85-
J_ = FiniteDiff.finite_difference_jacobian(uf, x, cache)
86-
copyto!(J, J_)
87-
return J
88-
end
89-
end
90-
end
44+
alg = NonlinearSolve.get_concrete_algorithm(alg, prob)
45+
J! = NonlinearSolve.__construct_jac(prob, alg, u;
46+
can_handle_arbitrary_dims = Val(true))
9147
else
92-
J! = InplaceFunction{iip}(prob.f.jac)
48+
J! = NonlinearSolve.__make_inplace{iip}(prob.f.jac, nothing)
9349
end
9450

9551
J = similar(u, length(fu), length(u))
@@ -107,8 +63,7 @@ function SciMLBase.solve!(cache::FastLevenbergMarquardtJLCache)
10763
res, fx, info, iter, nfev, njev, LM, solver = FastLM.lmsolve!(cache.f!, cache.J!,
10864
cache.lmworkspace, cache.prob.p; cache.solver, cache.kwargs...)
10965
stats = SciMLBase.NLStats(nfev, njev, -1, -1, iter)
110-
retcode = info == 1 ? ReturnCode.Success :
111-
(info == -1 ? ReturnCode.MaxIters : ReturnCode.Default)
66+
retcode = info == -1 ? ReturnCode.MaxIters : ReturnCode.Success
11267
return SciMLBase.build_solution(cache.prob, cache.alg, res, fx;
11368
retcode, original = (res, fx, info, iter, nfev, njev, LM, solver), stats)
11469
end

ext/NonlinearSolveFixedPointAccelerationExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::FixedPointAccelerationJL
77
show_trace::Val{PrintReports} = Val(false), termination_condition = nothing,
88
kwargs...) where {PrintReports}
99
@assert (termination_condition ===
10-
nothing)||(termination_condition isa AbsNormTerminationMode) "SpeedMappingJL does not support termination conditions!"
10+
nothing)||(termination_condition isa AbsNormTerminationMode) "FixedPointAccelerationJL does not support termination conditions!"
1111

1212
u0 = NonlinearSolve.__maybe_unaliased(prob.u0, alias_u0)
1313
u_size = size(u0)

ext/NonlinearSolveLeastSquaresOptimExt.jl

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,21 +17,14 @@ import LeastSquaresOptim as LSO
1717
end
1818
end
1919

20+
# TODO: Implement reinit
2021
@concrete struct LeastSquaresOptimJLCache
2122
prob
2223
alg
2324
allocated_prob
2425
kwargs
2526
end
2627

27-
@concrete struct FunctionWrapper{iip}
28-
f
29-
p
30-
end
31-
32-
(f::FunctionWrapper{true})(du, u) = f.f(du, u, f.p)
33-
(f::FunctionWrapper{false})(du, u) = (du .= f.f(u, f.p))
34-
3528
function SciMLBase.__init(prob::NonlinearLeastSquaresProblem, alg::LeastSquaresOptimJL,
3629
args...; alias_u0 = false, abstol = nothing, show_trace::Val{ShT} = Val(false),
3730
trace_level = TraceMinimal(), store_trace::Val{StT} = Val(false), maxiters = 1000,
@@ -42,8 +35,8 @@ function SciMLBase.__init(prob::NonlinearLeastSquaresProblem, alg::LeastSquaresO
4235
abstol = NonlinearSolve.DEFAULT_TOLERANCE(abstol, eltype(u))
4336
reltol = NonlinearSolve.DEFAULT_TOLERANCE(reltol, eltype(u))
4437

45-
f! = FunctionWrapper{iip}(prob.f, prob.p)
46-
g! = prob.f.jac === nothing ? nothing : FunctionWrapper{iip}(prob.f.jac, prob.p)
38+
f! = NonlinearSolve.__make_inplace{iip}(prob.f, prob.p)
39+
g! = NonlinearSolve.__make_inplace{iip}(prob.f.jac, prob.p)
4740

4841
resid_prototype = prob.f.resid_prototype === nothing ?
4942
(!iip ? prob.f(u, prob.p) : zeros(u)) : prob.f.resid_prototype

ext/NonlinearSolveMINPACKExt.jl

Lines changed: 23 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module NonlinearSolveMINPACKExt
22

33
using NonlinearSolve, DiffEqBase, SciMLBase
44
using MINPACK
5+
import FastClosures: @closure
56

67
function SciMLBase.__solve(prob::Union{NonlinearProblem{uType, iip},
78
NonlinearLeastSquaresProblem{uType, iip}}, alg::CMINPACK, args...;
@@ -11,80 +12,42 @@ function SciMLBase.__solve(prob::Union{NonlinearProblem{uType, iip},
1112
@assert (termination_condition ===
1213
nothing)||(termination_condition isa AbsNormTerminationMode) "CMINPACK does not support termination conditions!"
1314

14-
if prob.u0 isa Number
15-
u0 = [prob.u0]
16-
else
17-
u0 = NonlinearSolve.__maybe_unaliased(prob.u0, alias_u0)
18-
end
15+
f!_, u0 = NonlinearSolve.__construct_f(prob; alias_u0)
16+
f! = @closure (du, u) -> (f!_(du, u); Cint(0))
1917

20-
sizeu = size(prob.u0)
21-
p = prob.p
22-
23-
# unwrapping alg params
24-
show_trace = alg.show_trace || ShT
25-
tracing = alg.tracing || StT
26-
27-
if !iip && prob.u0 isa Number
28-
f! = (du, u) -> (du .= prob.f(first(u), p); Cint(0))
29-
elseif !iip && prob.u0 isa AbstractVector
30-
f! = (du, u) -> (du .= prob.f(u, p); Cint(0))
31-
elseif !iip && prob.u0 isa AbstractArray
32-
f! = (du, u) -> (du .= vec(prob.f(reshape(u, sizeu), p)); Cint(0))
33-
elseif prob.u0 isa AbstractVector
34-
f! = (du, u) -> prob.f(du, u, p)
35-
else # Then it's an in-place function on an abstract array
36-
f! = (du, u) -> (prob.f(reshape(du, sizeu), reshape(u, sizeu), p); du = vec(du); 0)
37-
end
38-
39-
u = zero(u0)
40-
resid = NonlinearSolve.evaluate_f(prob, u)
18+
resid = NonlinearSolve.evaluate_f(prob, prob.u0)
4119
m = length(resid)
42-
size_jac = (length(resid), length(u))
4320

4421
method = ifelse(alg.method === :auto,
4522
ifelse(prob isa NonlinearLeastSquaresProblem, :lm, :hybr), alg.method)
4623

47-
abstol = NonlinearSolve.DEFAULT_TOLERANCE(abstol, eltype(u))
24+
show_trace = alg.show_trace || ShT
25+
tracing = alg.tracing || StT
26+
tol = NonlinearSolve.DEFAULT_TOLERANCE(abstol, eltype(u0))
27+
28+
jac!_ = NonlinearSolve.__construct_jac(prob, alg, u0)
4829

49-
if SciMLBase.has_jac(prob.f)
50-
if !iip && prob.u0 isa Number
51-
g! = (du, u) -> (du .= prob.f.jac(first(u), p); Cint(0))
52-
elseif !iip && prob.u0 isa AbstractVector
53-
g! = (du, u) -> (du .= prob.f.jac(u, p); Cint(0))
54-
elseif !iip && prob.u0 isa AbstractArray
55-
g! = (du, u) -> (du .= vec(prob.f.jac(reshape(u, sizeu), p)); Cint(0))
56-
elseif prob.u0 isa AbstractVector
57-
g! = (du, u) -> prob.f.jac(du, u, p)
58-
else # Then it's an in-place function on an abstract array
59-
g! = function (du, u)
60-
prob.f.jac(reshape(du, size_jac), reshape(u, sizeu), p)
61-
return Cint(0)
62-
end
63-
end
64-
original = MINPACK.fsolve(f!, g!, vec(u0), m; tol = abstol, show_trace, tracing,
65-
method, iterations = maxiters)
30+
if jac!_ === nothing
31+
original = MINPACK.fsolve(f!, u0, m; tol, show_trace, tracing, method,
32+
iterations = maxiters)
6633
else
67-
original = MINPACK.fsolve(f!, vec(u0), m; tol = abstol, show_trace, tracing,
68-
method, iterations = maxiters)
34+
jac! = @closure((J, u) -> (jac!_(J, u); Cint(0)))
35+
original = MINPACK.fsolve(f!, jac!, u0, m; tol, show_trace, tracing, method,
36+
iterations = maxiters)
6937
end
7038

71-
u = reshape(original.x, size(u))
72-
resid = original.f
73-
# retcode = original.converged ? ReturnCode.Success : ReturnCode.Failure
74-
# MINPACK lies about convergence? or maybe uses some other criteria?
75-
# We just check for absolute tolerance on the residual
76-
objective = maximum(abs, resid)
77-
retcode = ifelse(objective abstol, ReturnCode.Success, ReturnCode.Failure)
39+
u = original.x
40+
resid_ = original.f
41+
objective = maximum(abs, resid_)
42+
retcode = ifelse(objective tol, ReturnCode.Success, ReturnCode.Failure)
7843

79-
# These are only meaningful if `tracing = true`
44+
# These are only meaningful if `store_trace = Val(true)`
8045
stats = SciMLBase.NLStats(original.trace.f_calls, original.trace.g_calls,
8146
original.trace.g_calls, original.trace.g_calls, -1)
8247

83-
if prob.u0 isa Number
84-
return SciMLBase.build_solution(prob, alg, u[1], resid[1]; stats, retcode, original)
85-
else
86-
return SciMLBase.build_solution(prob, alg, u, resid; stats, retcode, original)
87-
end
48+
u_ = prob.u0 isa Number ? original.x[1] : reshape(original.x, size(prob.u0))
49+
resid_ = prob.u0 isa Number ? resid_[1] : reshape(resid_, size(resid))
50+
return SciMLBase.build_solution(prob, alg, u_, resid_; retcode, original, stats)
8851
end
8952

9053
end

ext/NonlinearSolveNLsolveExt.jl

Lines changed: 15 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,71 +1,36 @@
11
module NonlinearSolveNLsolveExt
22

33
using NonlinearSolve, NLsolve, DiffEqBase, SciMLBase
4-
import UnPack: @unpack
54

65
function SciMLBase.__solve(prob::NonlinearProblem, alg::NLsolveJL, args...;
76
abstol = nothing, maxiters = 1000, alias_u0::Bool = false,
87
termination_condition = nothing, kwargs...)
98
@assert (termination_condition ===
109
nothing)||(termination_condition isa AbsNormTerminationMode) "NLsolveJL does not support termination conditions!"
1110

12-
if prob.u0 isa Number
13-
u0 = [prob.u0]
14-
else
15-
u0 = NonlinearSolve.__maybe_unaliased(prob.u0, alias_u0)
16-
end
17-
18-
iip = isinplace(prob)
19-
sizeu = size(prob.u0)
20-
p = prob.p
11+
f!, u0 = NonlinearSolve.__construct_f(prob; alias_u0)
2112

2213
# unwrapping alg params
23-
@unpack method, autodiff, store_trace, extended_trace, linesearch, linsolve = alg
24-
@unpack factor, autoscale, m, beta, show_trace = alg
25-
26-
if !iip && prob.u0 isa Number
27-
f! = (du, u) -> (du .= prob.f(first(u), p); Cint(0))
28-
elseif !iip && prob.u0 isa AbstractVector
29-
f! = (du, u) -> (du .= prob.f(u, p); Cint(0))
30-
elseif !iip && prob.u0 isa AbstractArray
31-
f! = (du, u) -> (du .= vec(prob.f(reshape(u, sizeu), p)); Cint(0))
32-
elseif prob.u0 isa AbstractVector
33-
f! = (du, u) -> prob.f(du, u, p)
34-
else # Then it's an in-place function on an abstract array
35-
f! = (du, u) -> (prob.f(reshape(du, sizeu), reshape(u, sizeu), p); du = vec(du); 0)
36-
end
14+
(; method, autodiff, store_trace, extended_trace, linesearch, linsolve, factor,
15+
autoscale, m, beta, show_trace) = alg
3716

3817
if prob.u0 isa Number
3918
resid = [NonlinearSolve.evaluate_f(prob, first(u0))]
4019
else
4120
resid = NonlinearSolve.evaluate_f(prob, u0)
4221
end
4322

44-
size_jac = (length(resid), length(u0))
23+
jac! = NonlinearSolve.__construct_jac(prob, alg, u0)
4524

46-
if SciMLBase.has_jac(prob.f)
47-
if !iip && prob.u0 isa Number
48-
g! = (du, u) -> (du .= prob.f.jac(first(u), p); Cint(0))
49-
elseif !iip && prob.u0 isa AbstractVector
50-
g! = (du, u) -> (du .= prob.f.jac(u, p); Cint(0))
51-
elseif !iip && prob.u0 isa AbstractArray
52-
g! = (du, u) -> (du .= vec(prob.f.jac(reshape(u, sizeu), p)); Cint(0))
53-
elseif prob.u0 isa AbstractVector
54-
g! = (du, u) -> prob.f.jac(du, u, p)
55-
else # Then it's an in-place function on an abstract array
56-
g! = function (du, u)
57-
prob.f.jac(reshape(du, size_jac), reshape(u, sizeu), p)
58-
return Cint(0)
59-
end
60-
end
25+
if jac! === nothing
26+
df = OnceDifferentiable(f!, vec(u0), vec(resid); autodiff)
27+
else
6128
if prob.f.jac_prototype !== nothing
6229
J = zero(prob.f.jac_prototype)
63-
df = OnceDifferentiable(f!, g!, vec(u0), vec(resid), J)
30+
df = OnceDifferentiable(f!, jac!, vec(u0), vec(resid), J)
6431
else
65-
df = OnceDifferentiable(f!, g!, vec(u0), vec(resid))
32+
df = OnceDifferentiable(f!, jac!, vec(u0), vec(resid))
6633
end
67-
else
68-
df = OnceDifferentiable(f!, vec(u0), vec(resid); autodiff)
6934
end
7035

7136
abstol = NonlinearSolve.DEFAULT_TOLERANCE(abstol, eltype(u0))
@@ -74,17 +39,16 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::NLsolveJL, args...;
7439
store_trace, extended_trace, linesearch, linsolve, factor, autoscale, m, beta,
7540
show_trace)
7641

77-
u = reshape(original.zero, size(u0))
78-
f!(vec(resid), vec(u))
42+
f!(vec(resid), original.zero)
43+
u = prob.u0 isa Number ? original.zero[1] : reshape(original.zero, size(prob.u0))
44+
resid = prob.u0 isa Number ? resid[1] : resid
45+
7946
retcode = original.x_converged || original.f_converged ? ReturnCode.Success :
8047
ReturnCode.Failure
8148
stats = SciMLBase.NLStats(original.f_calls, original.g_calls, original.g_calls,
8249
original.g_calls, original.iterations)
83-
if prob.u0 isa Number
84-
return SciMLBase.build_solution(prob, alg, u[1], resid[1]; retcode, original, stats)
85-
else
86-
return SciMLBase.build_solution(prob, alg, u, resid; retcode, original, stats)
87-
end
50+
51+
return SciMLBase.build_solution(prob, alg, u, resid; retcode, original, stats)
8852
end
8953

9054
end

0 commit comments

Comments
 (0)