Skip to content

Commit 52b3832

Browse files
Merge pull request #385 from SciML/default_autodiff
Make default polyalgs respect autodiff
2 parents 7897f20 + 30697ee commit 52b3832

File tree

3 files changed

+51
-20
lines changed

3 files changed

+51
-20
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NonlinearSolve"
22
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
33
authors = ["SciML"]
4-
version = "3.7.2"
4+
version = "3.7.3"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/default.jl

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -307,12 +307,13 @@ function FastShortcutNonlinearPolyalg(
307307
# and thus are not included in the polyalgorithm
308308
if SA
309309
if __is_complex(T)
310-
algs = (SimpleBroyden(), Broyden(; init_jacobian = Val(:true_jacobian)),
310+
algs = (SimpleBroyden(),
311+
Broyden(; init_jacobian = Val(:true_jacobian), autodiff),
311312
SimpleKlement(),
312313
NewtonRaphson(; concrete_jac, linsolve, precs, autodiff))
313314
else
314315
algs = (SimpleBroyden(),
315-
Broyden(; init_jacobian = Val(:true_jacobian)),
316+
Broyden(; init_jacobian = Val(:true_jacobian), autodiff),
316317
SimpleKlement(),
317318
NewtonRaphson(; concrete_jac, linsolve, precs, autodiff),
318319
NewtonRaphson(; concrete_jac, linsolve, precs,
@@ -322,13 +323,13 @@ function FastShortcutNonlinearPolyalg(
322323
end
323324
else
324325
if __is_complex(T)
325-
algs = (Broyden(), Broyden(; init_jacobian = Val(:true_jacobian)),
326-
Klement(; linsolve, precs),
326+
algs = (Broyden(), Broyden(; init_jacobian = Val(:true_jacobian), autodiff),
327+
Klement(; linsolve, precs, autodiff),
327328
NewtonRaphson(; concrete_jac, linsolve, precs, autodiff))
328329
else
329-
algs = (Broyden(),
330-
Broyden(; init_jacobian = Val(:true_jacobian)),
331-
Klement(; linsolve, precs),
330+
algs = (Broyden(; autodiff),
331+
Broyden(; init_jacobian = Val(:true_jacobian), autodiff),
332+
Klement(; linsolve, precs, autodiff),
332333
NewtonRaphson(; concrete_jac, linsolve, precs, autodiff),
333334
NewtonRaphson(; concrete_jac, linsolve, precs,
334335
linesearch = LineSearchesJL(; method = BackTracking()), autodiff),
@@ -343,7 +344,7 @@ end
343344

344345
"""
345346
FastShortcutNLLSPolyalg(::Type{T} = Float64; concrete_jac = nothing, linsolve = nothing,
346-
precs = DEFAULT_PRECS, kwargs...)
347+
precs = DEFAULT_PRECS, autodiff = nothing, kwargs...)
347348
348349
A polyalgorithm focused on balancing speed and robustness. It first tries less robust methods
349350
for more performance and then tries more robust techniques if the faster ones fail.
@@ -353,21 +354,25 @@ for more performance and then tries more robust techniques if the faster ones fa
353354
- `T`: The eltype of the initial guess. It is only used to check if some of the algorithms
354355
are compatible with the problem type. Defaults to `Float64`.
355356
"""
356-
function FastShortcutNLLSPolyalg(::Type{T} = Float64; concrete_jac = nothing,
357-
linsolve = nothing, precs = DEFAULT_PRECS, kwargs...) where {T}
357+
function FastShortcutNLLSPolyalg(
358+
::Type{T} = Float64; concrete_jac = nothing, linsolve = nothing,
359+
precs = DEFAULT_PRECS, autodiff = nothing, kwargs...) where {T}
358360
if __is_complex(T)
359-
algs = (GaussNewton(; concrete_jac, linsolve, precs, kwargs...),
360-
LevenbergMarquardt(; linsolve, precs, disable_geodesic = Val(true), kwargs...),
361-
LevenbergMarquardt(; linsolve, precs, kwargs...))
361+
algs = (GaussNewton(; concrete_jac, linsolve, precs, autodiff, kwargs...),
362+
LevenbergMarquardt(;
363+
linsolve, precs, autodiff, disable_geodesic = Val(true), kwargs...),
364+
LevenbergMarquardt(; linsolve, precs, autodiff, kwargs...))
362365
else
363-
algs = (GaussNewton(; concrete_jac, linsolve, precs, kwargs...),
364-
LevenbergMarquardt(; linsolve, precs, disable_geodesic = Val(true), kwargs...),
365-
TrustRegion(; concrete_jac, linsolve, precs, kwargs...),
366+
algs = (GaussNewton(; concrete_jac, linsolve, precs, autodiff, kwargs...),
367+
LevenbergMarquardt(;
368+
linsolve, precs, disable_geodesic = Val(true), autodiff, kwargs...),
369+
TrustRegion(; concrete_jac, linsolve, precs, autodiff, kwargs...),
366370
GaussNewton(; concrete_jac, linsolve, precs,
367-
linesearch = LineSearchesJL(; method = BackTracking()), kwargs...),
371+
linesearch = LineSearchesJL(; method = BackTracking()),
372+
autodiff, kwargs...),
368373
TrustRegion(; concrete_jac, linsolve, precs,
369-
radius_update_scheme = RadiusUpdateSchemes.Bastin, kwargs...),
370-
LevenbergMarquardt(; linsolve, precs, kwargs...))
374+
radius_update_scheme = RadiusUpdateSchemes.Bastin, autodiff, kwargs...),
375+
LevenbergMarquardt(; linsolve, precs, autodiff, kwargs...))
371376
end
372377
return NonlinearSolvePolyAlgorithm(algs, Val(:NLLS))
373378
end

test/misc/polyalg_tests.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,32 @@ end
7171
@test SciMLBase.successful_retcode(sol)
7272
end
7373

74+
@testitem "PolyAlgorithms Autodiff" begin
75+
cache = zeros(2)
76+
function f(du, u, p)
77+
cache .= u .* u
78+
du .= cache .- 2
79+
end
80+
u0 = [1.0, 1.0]
81+
probN = NonlinearProblem{true}(f, u0)
82+
83+
custom_polyalg = NonlinearSolvePolyAlgorithm((
84+
Broyden(; autodiff = AutoFiniteDiff()), LimitedMemoryBroyden()))
85+
86+
# Uses the `__solve` function
87+
solver = solve(probN; abstol = 1e-9)
88+
@test SciMLBase.successful_retcode(solver)
89+
@test_throws MethodError solve(probN, RobustMultiNewton(); abstol = 1e-9)
90+
@test SciMLBase.successful_retcode(solver)
91+
solver = solve(probN, RobustMultiNewton(; autodiff = AutoFiniteDiff()); abstol = 1e-9)
92+
@test SciMLBase.successful_retcode(solver)
93+
solver = solve(
94+
probN, FastShortcutNonlinearPolyalg(; autodiff = AutoFiniteDiff()); abstol = 1e-9)
95+
@test SciMLBase.successful_retcode(solver)
96+
solver = solve(probN, custom_polyalg; abstol = 1e-9)
97+
@test SciMLBase.successful_retcode(solver)
98+
end
99+
74100
@testitem "Simple Scalar Problem #187" begin
75101
using NaNMath
76102

0 commit comments

Comments
 (0)