Skip to content

Commit 52fd380

Browse files
Merge pull request #153 from HodgeLab/mb/immutable
add ImmutableNonlinearProblem
2 parents b186343 + b50726f commit 52fd380

16 files changed

+132
-40
lines changed

lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveChainRulesCoreExt.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@ module SimpleNonlinearSolveChainRulesCoreExt
22

33
using ChainRulesCore: ChainRulesCore, NoTangent
44
using DiffEqBase: DiffEqBase
5-
using SciMLBase: ChainRulesOriginator, NonlinearProblem, NonlinearLeastSquaresProblem
6-
using SimpleNonlinearSolve: SimpleNonlinearSolve
5+
using SciMLBase: ChainRulesOriginator, NonlinearLeastSquaresProblem
6+
using SimpleNonlinearSolve: SimpleNonlinearSolve, ImmutableNonlinearProblem
77

88
# The expectation here is that no-one is using this directly inside a GPU kernel. We can
99
# eventually lift this requirement using a custom adjoint
1010
function ChainRulesCore.rrule(::typeof(SimpleNonlinearSolve.__internal_solve_up),
11-
prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem},
11+
prob::Union{ImmutableNonlinearProblem, NonlinearLeastSquaresProblem},
1212
sensealg, u0, u0_changed, p, p_changed, alg, args...; kwargs...)
1313
out, ∇internal = DiffEqBase._solve_adjoint(
1414
prob, sensealg, u0, p, ChainRulesOriginator(), alg, args...; kwargs...)

lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveReverseDiffExt.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@ module SimpleNonlinearSolveReverseDiffExt
33
using ArrayInterface: ArrayInterface
44
using DiffEqBase: DiffEqBase
55
using ReverseDiff: ReverseDiff, TrackedArray, TrackedReal
6-
using SciMLBase: ReverseDiffOriginator, NonlinearProblem, NonlinearLeastSquaresProblem
7-
using SimpleNonlinearSolve: SimpleNonlinearSolve
6+
using SciMLBase: ReverseDiffOriginator, NonlinearLeastSquaresProblem
7+
using SimpleNonlinearSolve: SimpleNonlinearSolve, ImmutableNonlinearProblem
88
import SimpleNonlinearSolve: __internal_solve_up
99

10-
for pType in (NonlinearProblem, NonlinearLeastSquaresProblem)
10+
for pType in (ImmutableNonlinearProblem, NonlinearLeastSquaresProblem)
1111
@eval begin
1212
function __internal_solve_up(prob::$(pType), sensealg, u0::TrackedArray, u0_changed,
1313
p::TrackedArray, p_changed, alg, args...; kwargs...)

lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveTrackerExt.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
module SimpleNonlinearSolveTrackerExt
22

33
using DiffEqBase: DiffEqBase
4-
using SciMLBase: TrackerOriginator, NonlinearProblem, NonlinearLeastSquaresProblem, remake
5-
using SimpleNonlinearSolve: SimpleNonlinearSolve
4+
using SciMLBase: TrackerOriginator, NonlinearLeastSquaresProblem, remake
5+
using SimpleNonlinearSolve: SimpleNonlinearSolve, ImmutableNonlinearProblem
66
using Tracker: Tracker, TrackedArray
77

8-
for pType in (NonlinearProblem, NonlinearLeastSquaresProblem)
8+
for pType in (ImmutableNonlinearProblem, NonlinearLeastSquaresProblem)
99
@eval begin
1010
function SimpleNonlinearSolve.__internal_solve_up(
1111
prob::$(pType), sensealg, u0::TrackedArray,

lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,11 @@ using LinearAlgebra: LinearAlgebra, I, convert, copyto!, diagind, dot, issuccess
1919
norm, transpose
2020
using MaybeInplace: @bb, setindex_trait, CanSetindex, CannotSetindex
2121
using Reexport: @reexport
22-
using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem,
22+
using SciMLBase: @add_kwonly, SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem,
23+
AbstractNonlinearFunction, StandardNonlinearProblem,
2324
NonlinearFunction, NonlinearLeastSquaresProblem, NonlinearProblem,
2425
ReturnCode, init, remake, solve, AbstractNonlinearAlgorithm,
25-
build_solution, isinplace, _unwrap_val
26+
build_solution, isinplace, _unwrap_val, warn_paramtype
2627
using Setfield: @set!
2728
using StaticArraysCore: StaticArray, SVector, SMatrix, SArray, MArray, Size
2829

@@ -35,7 +36,7 @@ abstract type AbstractBracketingAlgorithm <: AbstractSimpleNonlinearSolveAlgorit
3536
abstract type AbstractNewtonAlgorithm <: AbstractSimpleNonlinearSolveAlgorithm end
3637

3738
@inline __is_extension_loaded(::Val) = false
38-
39+
include("immutable_nonlinear_problem.jl")
3940
include("utils.jl")
4041
include("linesearch.jl")
4142

@@ -70,6 +71,18 @@ end
7071
# By Pass the highlevel checks for NonlinearProblem for Simple Algorithms
7172
function SciMLBase.solve(prob::NonlinearProblem, alg::AbstractSimpleNonlinearSolveAlgorithm,
7273
args...; sensealg = nothing, u0 = nothing, p = nothing, kwargs...)
74+
prob = convert(ImmutableNonlinearProblem, prob)
75+
if sensealg === nothing && haskey(prob.kwargs, :sensealg)
76+
sensealg = prob.kwargs[:sensealg]
77+
end
78+
new_u0 = u0 !== nothing ? u0 : prob.u0
79+
new_p = p !== nothing ? p : prob.p
80+
return __internal_solve_up(prob, sensealg, new_u0, u0 === nothing, new_p,
81+
p === nothing, alg, args...; prob.kwargs..., kwargs...)
82+
end
83+
84+
function SciMLBase.solve(prob::ImmutableNonlinearProblem, alg::AbstractSimpleNonlinearSolveAlgorithm,
85+
args...; sensealg = nothing, u0 = nothing, p = nothing, kwargs...)
7386
if sensealg === nothing && haskey(prob.kwargs, :sensealg)
7487
sensealg = prob.kwargs[:sensealg]
7588
end
@@ -79,7 +92,7 @@ function SciMLBase.solve(prob::NonlinearProblem, alg::AbstractSimpleNonlinearSol
7992
p === nothing, alg, args...; prob.kwargs..., kwargs...)
8093
end
8194

82-
function __internal_solve_up(_prob::NonlinearProblem, sensealg, u0, u0_changed,
95+
function __internal_solve_up(_prob::ImmutableNonlinearProblem, sensealg, u0, u0_changed,
8396
p, p_changed, alg, args...; kwargs...)
8497
prob = u0_changed || p_changed ? remake(_prob; u0, p) : _prob
8598
return SciMLBase.__solve(prob, alg, args...; kwargs...)

lib/SimpleNonlinearSolve/src/ad.jl

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,26 @@
1-
for pType in (NonlinearProblem, NonlinearLeastSquaresProblem)
2-
@eval function SciMLBase.solve(
3-
prob::$(pType){<:Union{Number, <:AbstractArray}, iip,
4-
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}},
5-
alg::AbstractSimpleNonlinearSolveAlgorithm,
6-
args...;
7-
kwargs...) where {T, V, P, iip}
8-
sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...)
9-
dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p)
10-
return SciMLBase.build_solution(
11-
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
12-
end
1+
function SciMLBase.solve(
2+
prob::NonlinearLeastSquaresProblem{<:Union{Number, <:AbstractArray}, iip,
3+
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}},
4+
alg::AbstractSimpleNonlinearSolveAlgorithm,
5+
args...;
6+
kwargs...) where {T, V, P, iip}
7+
sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...)
8+
dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p)
9+
return SciMLBase.build_solution(
10+
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
11+
end
12+
13+
function SciMLBase.solve(
14+
prob::NonlinearProblem{<:Union{Number, <:AbstractArray}, iip,
15+
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}},
16+
alg::AbstractSimpleNonlinearSolveAlgorithm,
17+
args...;
18+
kwargs...) where {T, V, P, iip}
19+
prob = convert(ImmutableNonlinearProblem, prob)
20+
sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...)
21+
dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p)
22+
return SciMLBase.build_solution(
23+
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
1324
end
1425

1526
for algType in (Bisection, Brent, Alefeld, Falsi, ITP, Ridder)
@@ -31,7 +42,7 @@ for algType in (Bisection, Brent, Alefeld, Falsi, ITP, Ridder)
3142
end
3243

3344
function __nlsolve_ad(
34-
prob::Union{IntervalNonlinearProblem, NonlinearProblem}, alg, args...; kwargs...)
45+
prob::Union{IntervalNonlinearProblem, NonlinearProblem, ImmutableNonlinearProblem}, alg, args...; kwargs...)
3546
p = value(prob.p)
3647
if prob isa IntervalNonlinearProblem
3748
tspan = value.(prob.tspan)
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
struct ImmutableNonlinearProblem{uType, isinplace, P, F, K, PT} <:
2+
AbstractNonlinearProblem{uType, isinplace}
3+
f::F
4+
u0::uType
5+
p::P
6+
problem_type::PT
7+
kwargs::K
8+
@add_kwonly function ImmutableNonlinearProblem{iip}(f::AbstractNonlinearFunction{iip}, u0,
9+
p = NullParameters(),
10+
problem_type = StandardNonlinearProblem();
11+
kwargs...) where {iip}
12+
if haskey(kwargs, :p)
13+
error("`p` specified as a keyword argument `p = $(kwargs[:p])` to `NonlinearProblem`. This is not supported.")
14+
end
15+
warn_paramtype(p)
16+
new{typeof(u0), iip, typeof(p), typeof(f),
17+
typeof(kwargs), typeof(problem_type)}(f,
18+
u0,
19+
p,
20+
problem_type,
21+
kwargs)
22+
end
23+
24+
"""
25+
Define a steady state problem using the given function.
26+
`isinplace` optionally sets whether the function is inplace or not.
27+
This is determined automatically, but not inferred.
28+
"""
29+
function ImmutableNonlinearProblem{iip}(f, u0, p = NullParameters(); kwargs...) where {iip}
30+
ImmutableNonlinearProblem{iip}(NonlinearFunction{iip}(f), u0, p; kwargs...)
31+
end
32+
end
33+
34+
"""
35+
Define a nonlinear problem using an instance of
36+
[`AbstractNonlinearFunction`](@ref AbstractNonlinearFunction).
37+
"""
38+
function ImmutableNonlinearProblem(f::AbstractNonlinearFunction, u0, p = NullParameters(); kwargs...)
39+
ImmutableNonlinearProblem{isinplace(f)}(f, u0, p; kwargs...)
40+
end
41+
42+
function ImmutableNonlinearProblem(f, u0, p = NullParameters(); kwargs...)
43+
ImmutableNonlinearProblem(NonlinearFunction(f), u0, p; kwargs...)
44+
end
45+
46+
"""
47+
Define a ImmutableNonlinearProblem problem from SteadyStateProblem
48+
"""
49+
function ImmutableNonlinearProblem(prob::AbstractNonlinearProblem)
50+
ImmutableNonlinearProblem{isinplace(prob)}(prob.f, prob.u0, prob.p)
51+
end
52+
53+
54+
function Base.convert(::Type{ImmutableNonlinearProblem}, prob::T) where {T <: NonlinearProblem}
55+
ImmutableNonlinearProblem{isinplace(prob)}(prob.f,
56+
prob.u0,
57+
prob.p,
58+
prob.problem_type;
59+
prob.kwargs...)
60+
end
61+
62+
function DiffEqBase.get_concrete_problem(prob::ImmutableNonlinearProblem, isadapt; kwargs...)
63+
u0 = DiffEqBase.get_concrete_u0(prob, isadapt, nothing, kwargs)
64+
u0 = DiffEqBase.promote_u0(u0, prob.p, nothing)
65+
p = DiffEqBase.get_concrete_p(prob, kwargs)
66+
DiffEqBase.remake(prob; u0 = u0, p = p)
67+
end

lib/SimpleNonlinearSolve/src/nlsolve/broyden.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ end
2222

2323
__get_linesearch(::SimpleBroyden{LS}) where {LS} = Val(LS)
2424

25-
function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleBroyden, args...;
25+
function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleBroyden, args...;
2626
abstol = nothing, reltol = nothing, maxiters = 1000,
2727
alias_u0 = false, termination_condition = nothing, kwargs...)
2828
x = __maybe_unaliased(prob.u0, alias_u0)

lib/SimpleNonlinearSolve/src/nlsolve/dfsane.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ function SimpleDFSane(; σ_min::Real = 1e-10, σ_max::Real = 1e10, σ_1::Real =
5454
σ_min, σ_max, σ_1, γ, τ_min, τ_max, nexp, η_strategy)
5555
end
5656

57-
function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{M}, args...;
57+
function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleDFSane{M}, args...;
5858
abstol = nothing, reltol = nothing, maxiters = 1000, alias_u0 = false,
5959
termination_condition = nothing, kwargs...) where {M}
6060
x = __maybe_unaliased(prob.u0, alias_u0)

lib/SimpleNonlinearSolve/src/nlsolve/halley.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ A low-overhead implementation of Halley's Method.
2424
autodiff = nothing
2525
end
2626

27-
function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleHalley, args...;
27+
function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleHalley, args...;
2828
abstol = nothing, reltol = nothing, maxiters = 1000,
2929
alias_u0 = false, termination_condition = nothing, kwargs...)
3030
x = __maybe_unaliased(prob.u0, alias_u0)

lib/SimpleNonlinearSolve/src/nlsolve/klement.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ method is non-allocating on scalar and static array problems.
66
"""
77
struct SimpleKlement <: AbstractSimpleNonlinearSolveAlgorithm end
88

9-
function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleKlement, args...;
9+
function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleKlement, args...;
1010
abstol = nothing, reltol = nothing, maxiters = 1000,
1111
alias_u0 = false, termination_condition = nothing, kwargs...)
1212
x = __maybe_unaliased(prob.u0, alias_u0)

0 commit comments

Comments
 (0)