Skip to content

Commit 02160aa

Browse files
committed
Use DifferentiationInterface
1 parent b84ee8f commit 02160aa

File tree

7 files changed

+78
-230
lines changed

7 files changed

+78
-230
lines changed

lib/SimpleNonlinearSolve/Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
name = "SimpleNonlinearSolve"
22
uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7"
33
authors = ["SciML"]
4-
version = "1.8.1"
4+
version = "1.9.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
88
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
99
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
1010
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
1111
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
12+
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
1213
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
1314
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
1415
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
@@ -21,15 +22,13 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
2122

2223
[weakdeps]
2324
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
24-
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
2525
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
2626
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2727
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
2828
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2929

3030
[extensions]
3131
SimpleNonlinearSolveChainRulesCoreExt = "ChainRulesCore"
32-
SimpleNonlinearSolvePolyesterForwardDiffExt = "PolyesterForwardDiff"
3332
SimpleNonlinearSolveReverseDiffExt = "ReverseDiff"
3433
SimpleNonlinearSolveStaticArraysExt = "StaticArrays"
3534
SimpleNonlinearSolveTrackerExt = "Tracker"
@@ -45,6 +44,7 @@ ChainRulesCore = "1.22"
4544
ConcreteStructs = "0.2.3"
4645
DiffEqBase = "6.149"
4746
DiffResults = "1.1"
47+
DifferentiationInterface = "0.4"
4848
ExplicitImports = "1.5.0"
4949
FastClosures = "0.3.2"
5050
FiniteDiff = "2.22"

lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolvePolyesterForwardDiffExt.jl

Lines changed: 0 additions & 20 deletions
This file was deleted.

lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@ module SimpleNonlinearSolve
33
using PrecompileTools: @compile_workload, @setup_workload, @recompile_invalidations
44

55
@recompile_invalidations begin
6-
using ADTypes: ADTypes, AutoFiniteDiff, AutoForwardDiff, AutoPolyesterForwardDiff
6+
using ADTypes: ADTypes, AbstractADType, AutoFiniteDiff, AutoForwardDiff,
7+
AutoPolyesterForwardDiff
78
using ArrayInterface: ArrayInterface
89
using ConcreteStructs: @concrete
910
using DiffEqBase: DiffEqBase, AbstractNonlinearTerminationMode,
1011
AbstractSafeNonlinearTerminationMode,
1112
AbstractSafeBestNonlinearTerminationMode, AbsNormTerminationMode,
1213
NONLINEARSOLVE_DEFAULT_NORM
14+
using DifferentiationInterface: DifferentiationInterface
1315
using DiffResults: DiffResults
1416
using FastClosures: @closure
1517
using FiniteDiff: FiniteDiff
@@ -25,6 +27,8 @@ using PrecompileTools: @compile_workload, @setup_workload, @recompile_invalidati
2527
using StaticArraysCore: StaticArray, SVector, SMatrix, SArray, MArray, Size
2628
end
2729

30+
const DI = DifferentiationInterface
31+
2832
@reexport using SciMLBase
2933

3034
abstract type AbstractSimpleNonlinearSolveAlgorithm <: AbstractNonlinearAlgorithm end

lib/SimpleNonlinearSolve/src/nlsolve/halley.jl

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@ A low-overhead implementation of Halley's Method.
1212
1313
### Keyword Arguments
1414
15-
- `autodiff`: determines the backend used for the Hessian. Defaults to `nothing`. Valid
16-
choices are `AutoForwardDiff()` or `AutoFiniteDiff()`.
15+
- `autodiff`: determines the backend used for the Hessian. Defaults to `nothing` (i.e.
16+
automatic backend selection). Valid choices include backends from
17+
`DifferentiationInterface.jl`.
1718
1819
!!! warning
1920
@@ -26,13 +27,11 @@ end
2627
function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleHalley, args...;
2728
abstol = nothing, reltol = nothing, maxiters = 1000,
2829
alias_u0 = false, termination_condition = nothing, kwargs...)
29-
isinplace(prob) &&
30-
error("SimpleHalley currently only supports out-of-place nonlinear problems")
31-
3230
x = __maybe_unaliased(prob.u0, alias_u0)
3331
fx = _get_fx(prob, x)
3432
T = eltype(x)
3533

34+
f = __fixed_parameter_function(prob)
3635
autodiff = __get_concrete_autodiff(prob, alg.autodiff)
3736
abstol, reltol, tc_cache = init_termination_cache(
3837
prob, abstol, reltol, fx, x, termination_condition)
@@ -51,7 +50,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleHalley, args...;
5150

5251
for i in 1:maxiters
5352
# Hessian Computation is unfortunately type unstable
54-
fx, dfx, d2fx = compute_jacobian_and_hessian(autodiff, prob, fx, x)
53+
fx, dfx, d2fx = compute_jacobian_and_hessian(autodiff, prob, f, fx, x)
5554
setindex_trait(x) === CannotSetindex() && (A = dfx)
5655

5756
# Factorize Once and Reuse
@@ -78,9 +77,8 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleHalley, args...;
7877
cᵢ = _restructure(cᵢ, cᵢ_)
7978

8079
if i == 1
81-
if iszero(fx)
80+
iszero(fx) &&
8281
return build_solution(prob, alg, x, fx; retcode = ReturnCode.Success)
83-
end
8482
else
8583
# Termination Checks
8684
tc_sol = check_termination(tc_cache, fx, x, xo, prob, alg)

lib/SimpleNonlinearSolve/src/nlsolve/raphson.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@ and static array problems.
1313
1414
### Keyword Arguments
1515
16-
- `autodiff`: determines the backend used for the Jacobian. Defaults to
17-
`nothing`. Valid choices are `AutoPolyesterForwardDiff()`, `AutoForwardDiff()` or
18-
`AutoFiniteDiff()`.
16+
- `autodiff`: determines the backend used for the Jacobian. Defaults to `nothing` (i.e.
17+
automatic backend selection). Valid choices include jacobian backends from
18+
`DifferentiationInterface.jl`.
1919
"""
2020
@kwdef @concrete struct SimpleNewtonRaphson <: AbstractNewtonAlgorithm
2121
autodiff = nothing
@@ -30,13 +30,14 @@ function SciMLBase.__solve(prob::Union{NonlinearProblem, NonlinearLeastSquaresPr
3030
fx = _get_fx(prob, x)
3131
autodiff = __get_concrete_autodiff(prob, alg.autodiff)
3232
@bb xo = copy(x)
33-
J, jac_cache = jacobian_cache(autodiff, prob.f, fx, x, prob.p)
33+
f = __fixed_parameter_function(prob)
34+
J, jac_cache = jacobian_cache(autodiff, prob, f, fx, x)
3435

3536
abstol, reltol, tc_cache = init_termination_cache(
3637
prob, abstol, reltol, fx, x, termination_condition)
3738

3839
for i in 1:maxiters
39-
fx, dfx = value_and_jacobian(autodiff, prob.f, fx, x, prob.p, jac_cache; J)
40+
fx, dfx = value_and_jacobian(autodiff, prob, f, fx, x, jac_cache; J)
4041

4142
if i == 1
4243
iszero(fx) && build_solution(prob, alg, x, fx; retcode = ReturnCode.Success)

lib/SimpleNonlinearSolve/src/nlsolve/trustRegion.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@ scalar and static array problems.
1010
1111
### Keyword Arguments
1212
13-
- `autodiff`: determines the backend used for the Jacobian. Defaults to
14-
`nothing`. Valid choices are `AutoPolyesterForwardDiff()`, `AutoForwardDiff()` or
15-
`AutoFiniteDiff()`.
13+
- `autodiff`: determines the backend used for the Jacobian. Defaults to `nothing` (i.e.
14+
automatic backend selection). Valid choices include jacobian backends from
15+
`DifferentiationInterface.jl`.
1616
- `max_trust_radius`: the maximum radius of the trust region. Defaults to
1717
`max(norm(f(u0)), maximum(u0) - minimum(u0))`.
1818
- `initial_trust_radius`: the initial trust region radius. Defaults to
@@ -85,8 +85,9 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleTrustRegion, args.
8585
fx = _get_fx(prob, x)
8686
norm_fx = norm(fx)
8787
@bb xo = copy(x)
88-
J, jac_cache = jacobian_cache(autodiff, prob.f, fx, x, prob.p)
89-
fx, ∇f = value_and_jacobian(autodiff, prob.f, fx, x, prob.p, jac_cache; J)
88+
f = __fixed_parameter_function(prob)
89+
J, jac_cache = jacobian_cache(autodiff, prob, f, fx, x)
90+
fx, ∇f = value_and_jacobian(autodiff, prob, f, fx, x, jac_cache; J)
9091

9192
abstol, reltol, tc_cache = init_termination_cache(
9293
prob, abstol, reltol, fx, x, termination_condition)
@@ -144,7 +145,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleTrustRegion, args.
144145
# Take the step.
145146
@bb @. xo = x
146147

147-
fx, ∇f = value_and_jacobian(autodiff, prob.f, fx, x, prob.p, jac_cache; J)
148+
fx, ∇f = value_and_jacobian(autodiff, prob, f, fx, x, jac_cache; J)
148149

149150
# Update the trust region radius.
150151
if !_unwrap_val(alg.nlsolve_update_rule) && r > η₃

0 commit comments

Comments
 (0)