Skip to content

Commit 138de9b

Browse files
Merge pull request #347 from oscardssmith/os/robustmultinewton-autodiff
make RobustMultiNewton always respect autodiff choice
2 parents d47b131 + 8614ebe commit 138de9b

File tree

6 files changed

+34
-7
lines changed

6 files changed

+34
-7
lines changed

docs/src/solvers/FixedPointSolvers.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,4 +50,4 @@ robust.
5050

5151
### SIAMFANLEquations.jl
5252

53-
- `SIAMFANLEquationsJL(; method = :anderson)`: Anderson acceleration for fixed point problems.
53+
- `SIAMFANLEquationsJL(; method = :anderson)`: Anderson acceleration for fixed point problems.

ext/NonlinearSolveSIAMFANLEquationsExt.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, arg
4343
rtol = NonlinearSolve.DEFAULT_TOLERANCE(reltol, T)
4444

4545
if prob.u0 isa Number
46-
f = method == :anderson ? (du, u) -> (du = prob.f(u, prob.p)) : ((u) -> prob.f(u, prob.p))
46+
f = method == :anderson ? (du, u) -> (du = prob.f(u, prob.p)) :
47+
((u) -> prob.f(u, prob.p))
4748

4849
if method == :newton
4950
sol = nsolsc(f, prob.u0; maxit = maxiters, atol, rtol, printerr = ShT)
@@ -55,7 +56,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, arg
5556
elseif method == :anderson
5657
f, u = NonlinearSolve.__construct_f(prob; alias_u0,
5758
make_fixed_point = Val(true), can_handle_arbitrary_dims = Val(true))
58-
sol = aasol(f, [prob.u0], m, __zeros_like(u, 1, 2*m+4); maxit = maxiters,
59+
sol = aasol(f, [prob.u0], m, __zeros_like(u, 1, 2 * m + 4); maxit = maxiters,
5960
atol, rtol, beta = beta)
6061
end
6162

@@ -110,7 +111,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, arg
110111
elseif method == :anderson
111112
f!, u = NonlinearSolve.__construct_f(prob; alias_u0,
112113
can_handle_arbitrary_dims = Val(true), make_fixed_point = Val(true))
113-
sol = aasol(f!, u, m, zeros(T, N, 2*m+4), atol = atol, rtol = rtol,
114+
sol = aasol(f!, u, m, zeros(T, N, 2 * m + 4), atol = atol, rtol = rtol,
114115
maxit = maxiters, beta = beta)
115116
end
116117
else

src/default.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ function RobustMultiNewton(::Type{T} = Float64; concrete_jac = nothing, linsolve
207207
# Let's atleast have something here for complex numbers
208208
algs = (NewtonRaphson(; concrete_jac, linsolve, precs, autodiff),)
209209
else
210-
algs = (TrustRegion(; concrete_jac, linsolve, precs),
210+
algs = (TrustRegion(; concrete_jac, linsolve, precs, autodiff),
211211
TrustRegion(; concrete_jac, linsolve, precs, autodiff,
212212
radius_update_scheme = RadiusUpdateSchemes.Bastin),
213213
NewtonRaphson(; concrete_jac, linsolve, precs, linesearch = BackTracking(),

test/misc/no_ad.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
using LinearAlgebra, NonlinearSolve, Test
2+
3+
@testset "[IIP] no AD" begin
4+
f_iip = Base.Experimental.@opaque (du, u, p) -> du .= u .* u .- p
5+
u0 = [0.0]
6+
prob = NonlinearProblem(f_iip, u0, 1.0)
7+
for alg in [RobustMultiNewton(autodiff = AutoFiniteDiff()())]
8+
sol = solve(prob, alg)
9+
@test isapprox(only(sol.u), 1.0)
10+
@test SciMLBase.successful_retcode(sol.retcode)
11+
end
12+
end
13+
14+
@testset "[OOP] no AD" begin
15+
f_oop = Base.Experimental.@opaque (u, p) -> u .* u .- p
16+
u0 = [0.0]
17+
prob = NonlinearProblem{false}(f_oop, u0, 1.0)
18+
for alg in [RobustMultiNewton(autodiff = AutoFiniteDiff())]
19+
sol = solve(prob, alg)
20+
@test isapprox(only(sol.u), 1.0)
21+
@test SciMLBase.successful_retcode(sol.retcode)
22+
end
23+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ end
3535
@time @safetestset "Matrix Resizing" include("misc/matrix_resizing.jl")
3636
@time @safetestset "Infeasible Problems" include("misc/infeasible.jl")
3737
@time @safetestset "Banded Matrices" include("misc/banded_matrices.jl")
38+
@time @safetestset "No AD" include("misc/no_ad.jl")
3839
end
3940

4041
if GROUP == "GPU"

test/wrappers/fixedpoint.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using NonlinearSolve, FixedPointAcceleration, SpeedMapping, NLsolve, SIAMFANLEquations, LinearAlgebra, Test
1+
using NonlinearSolve,
2+
FixedPointAcceleration, SpeedMapping, NLsolve, SIAMFANLEquations, LinearAlgebra, Test
23

34
# Simple Scalar Problem
45
@testset "Simple Scalar Problem" begin
@@ -29,7 +30,8 @@ end
2930
@test maximum(abs.(solve(prob, SpeedMappingJL()).resid)) 1e-10
3031
@test maximum(abs.(solve(prob, SpeedMappingJL(; orders = [3, 2])).resid)) 1e-10
3132
@test maximum(abs.(solve(prob, SpeedMappingJL(; stabilize = true)).resid)) 1e-10
32-
@test maximum(abs.(solve(prob, SIAMFANLEquationsJL(; method = :anderson)).resid)) 1e-10
33+
@test maximum(abs.(solve(prob, SIAMFANLEquationsJL(; method = :anderson)).resid))
34+
1e-10
3335

3436
@test_broken maximum(abs.(solve(prob, NLsolveJL(; method = :anderson)).resid)) 1e-10
3537
end

0 commit comments

Comments
 (0)