Skip to content

Commit 534c41c

Browse files
authored
Merge pull request #361 from oscardssmith/fix-TrustRegion-autodiff
fix `autodiff=AutoFiniteDiff()` for `TrustRegion`
2 parents aecd854 + cc3ebb7 commit 534c41c

File tree

2 files changed

+16
-6
lines changed

2 files changed

+16
-6
lines changed

src/algorithms/trust_region.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,16 @@ function TrustRegion(; concrete_jac = nothing, linsolve = nothing, precs = DEFAU
2424
initial_trust_radius::Real = 0 // 1, step_threshold::Real = 1 // 10000,
2525
shrink_threshold::Real = 1 // 4, expand_threshold::Real = 3 // 4,
2626
shrink_factor::Real = 1 // 4, expand_factor::Real = 2 // 1,
27-
max_shrink_times::Int = 32, vjp_autodiff = nothing, autodiff = nothing)
27+
max_shrink_times::Int = 32, autodiff = nothing, vjp_autodiff = nothing)
2828
descent = Dogleg(; linsolve, precs)
29-
forward_ad = autodiff isa ADTypes.AbstractForwardMode ? autodiff : nothing
29+
if autodiff isa Union{ADTypes.AbstractForwardMode, ADTypes.AbstractFiniteDifferencesMode}
30+
forward_ad = autodiff
31+
else
32+
forward_ad = nothing
33+
end
34+
if isnothing(vjp_autodiff) && autodiff isa ADTypes.AbstractFiniteDifferencesMode
35+
vjp_autodiff = autodiff
36+
end
3037
trustregion = GenericTrustRegionScheme(; method = radius_update_scheme, step_threshold,
3138
shrink_threshold, expand_threshold, shrink_factor, expand_factor,
3239
reverse_ad = vjp_autodiff, forward_ad)

test/misc/polyalgs.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -93,11 +93,14 @@ end
9393
maxiters = 10)
9494
end
9595

96+
no_ad_fast = FastShortcutNonlinearPolyalg(autodiff=AutoFiniteDiff())
97+
no_ad_robust = RobustMultiNewton(autodiff=AutoFiniteDiff())
98+
no_ad_algs = Set([no_ad_fast, no_ad_robust, no_ad_fast.algs..., no_ad_robust.algs...])
9699
@testset "[IIP] no AD" begin
97100
f_iip = Base.Experimental.@opaque (du, u, p) -> du .= u .* u .- p
98-
u0 = [0.0]
101+
u0 = [0.5]
99102
prob = NonlinearProblem(f_iip, u0, 1.0)
100-
for alg in [RobustMultiNewton(autodiff = AutoFiniteDiff())]
103+
for alg in no_ad_algs
101104
sol = solve(prob, alg)
102105
@test isapprox(only(sol.u), 1.0)
103106
@test SciMLBase.successful_retcode(sol.retcode)
@@ -106,9 +109,9 @@ end
106109

107110
@testset "[OOP] no AD" begin
108111
f_oop = Base.Experimental.@opaque (u, p) -> u .* u .- p
109-
u0 = [0.0]
112+
u0 = [0.5]
110113
prob = NonlinearProblem{false}(f_oop, u0, 1.0)
111-
for alg in [RobustMultiNewton(autodiff = AutoFiniteDiff())]
114+
for alg in no_ad_algs
112115
sol = solve(prob, alg)
113116
@test isapprox(only(sol.u), 1.0)
114117
@test SciMLBase.successful_retcode(sol.retcode)

0 commit comments

Comments
 (0)