Skip to content

Commit 07470f9

Browse files
Merge pull request #2309 from SciML/nonlinearsolve_diff
Split off Nonlinear Solvers and Differentiation
2 parents a386684 + f772585 commit 07470f9

File tree

46 files changed

+1144
-1016
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

46 files changed

+1144
-1016
lines changed

lib/OrdinaryDiffEqBDF/src/OrdinaryDiffEqBDF.jl

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,25 +9,26 @@ import OrdinaryDiffEq: alg_order, calculate_residuals!,
99
OrdinaryDiffEqNewtonAlgorithm,
1010
AbstractController, DEFAULT_PRECS,
1111
CompiledFloats, uses_uprev,
12-
NLNewton, alg_cache, _vec, _reshape, @cache,
13-
isfsal, full_cache, build_nlsolver,
14-
nlsolve!, nlsolvefail, isnewton,
12+
alg_cache, _vec, _reshape, @cache,
13+
isfsal, full_cache,
1514
constvalue, isadaptive, error_constant,
16-
DIRK, set_new_W!, has_special_newton_error,
17-
du_alias_or_new, trivial_limiter!,
15+
has_special_newton_error,
16+
trivial_limiter!,
1817
ImplicitEulerConstantCache,
19-
compute_step!,
20-
ImplicitEulerCache, COEFFICIENT_MULTISTEP,
21-
markfirststage!, UJacobianWrapper, mul!,
18+
19+
ImplicitEulerCache,
2220
issplit, qsteady_min_default, qsteady_max_default,
2321
get_current_alg_order, get_current_adaptive_order,
2422
default_controller, stepsize_controller!, step_accept_controller!,
2523
step_reject_controller!, post_newton_controller!,
2624
u_modified!, DAEAlgorithm, _unwrap_val, DummyController
2725
using TruncatedStacktraces, MuladdMacro, MacroTools, FastBroadcast, RecursiveArrayTools
2826
import StaticArrays: SArray, MVector, SVector, @SVector, StaticArray, MMatrix, SA
29-
using LinearAlgebra: I
27+
using LinearAlgebra: mul!, I
3028
using ArrayInterface
29+
using OrdinaryDiffEq.OrdinaryDiffEqDifferentiation: UJacobianWrapper
30+
using OrdinaryDiffEq.OrdinaryDiffEqNonlinearSolve: NLNewton, du_alias_or_new, build_nlsolver,
31+
nlsolve!, nlsolvefail, isnewton, markfirststage!, set_new_W!, DIRK, compute_step!, COEFFICIENT_MULTISTEP
3132

3233
include("algorithms.jl")
3334
include("alg_utils.jl")
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
name = "OrdinaryDiffEqDifferentiation"
2+
uuid = "4302a76b-040a-498a-8c04-15b101fed76b"
3+
authors = ["Chris Rackauckas <accounts@chrisrackauckas.com>", "Yingbo Ma <mayingbo5@gmail.com>"]
4+
version = "1.0.0"
5+
6+
[deps]
7+
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
8+
9+
[compat]
10+
julia = "1.10"
11+
12+
[extras]
13+
DiffEqDevTools = "f3b72e0c-5b89-59e1-b016-84e28bfd966d"
14+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
15+
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
16+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
17+
18+
[targets]
19+
test = ["DiffEqDevTools", "Random", "SafeTestsets", "Test"]
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
module OrdinaryDiffEqDifferentiation
2+
3+
import ADTypes: AutoFiniteDiff, AutoForwardDiff
4+
5+
import SparseDiffTools: SparseDiffTools, matrix_colors, forwarddiff_color_jacobian!,
6+
forwarddiff_color_jacobian, ForwardColorJacCache,
7+
default_chunk_size, getsize, JacVec
8+
9+
import ForwardDiff, FiniteDiff
10+
import ForwardDiff.Dual
11+
import LinearSolve
12+
import LinearSolve: OperatorAssumptions
13+
import FunctionWrappersWrappers
14+
using DiffEqBase
15+
16+
import LinearAlgebra
17+
import LinearAlgebra: Diagonal, I, UniformScaling, diagind, mul!,lmul!, axpby!, opnorm
18+
import SparseArrays: SparseMatrixCSC, AbstractSparseMatrix, nonzeros
19+
20+
import InteractiveUtils
21+
import ArrayInterface
22+
23+
import StaticArrayInterface
24+
import StaticArrays: SArray, MVector, SVector, @SVector, StaticArray, MMatrix, SA,
25+
StaticMatrix
26+
27+
using DiffEqBase: TimeGradientWrapper,
28+
UJacobianWrapper, TimeDerivativeWrapper,
29+
UDerivativeWrapper
30+
using SciMLBase: AbstractSciMLOperator
31+
using OrdinaryDiffEq: OrdinaryDiffEqAlgorithm, OrdinaryDiffEqAdaptiveImplicitAlgorithm, DAEAlgorithm,
32+
OrdinaryDiffEqImplicitAlgorithm, CompositeAlgorithm, OrdinaryDiffEqExponentialAlgorithm,
33+
OrdinaryDiffEqAdaptiveExponentialAlgorithm, @unpack, AbstractNLSolver, nlsolve_f, issplit,
34+
concrete_jac, unwrap_alg, OrdinaryDiffEqCache, _vec, standardtag, isnewton, _unwrap_val,
35+
set_new_W!, set_W_γdt!, alg_difftype, unwrap_cache, diffdir, get_W, isfirstcall, isfirststage, isJcurrent, get_new_W_γdt_cutoff,
36+
TryAgain, DIRK, COEFFICIENT_MULTISTEP, NORDSIECK_MULTISTEP, GLM, FastConvergence, Convergence, SlowConvergence, VerySlowConvergence, Divergence, NLStatus, MethodType, constvalue
37+
38+
import OrdinaryDiffEq: get_chunksize, resize_J_W!, resize_nlsolver!
39+
40+
import OrdinaryDiffEq: alg_autodiff
41+
42+
using FastBroadcast: @..
43+
44+
@static if isdefined(DiffEqBase, :OrdinaryDiffEqTag)
45+
import DiffEqBase: OrdinaryDiffEqTag
46+
else
47+
struct OrdinaryDiffEqTag end
48+
end
49+
50+
include("alg_utils.jl")
51+
include("linsolve_utils.jl")
52+
include("derivative_utils.jl")
53+
include("derivative_wrappers.jl")
54+
55+
end
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# Extract AD type parameter from algorithm, returning as Val to ensure type stability for boolean options.
2+
function _alg_autodiff(alg::OrdinaryDiffEqAlgorithm)
3+
error("This algorithm does not have an autodifferentiation option defined.")
4+
end
5+
_alg_autodiff(::OrdinaryDiffEqAdaptiveImplicitAlgorithm{CS, AD}) where {CS, AD} = Val{AD}()
6+
_alg_autodiff(::DAEAlgorithm{CS, AD}) where {CS, AD} = Val{AD}()
7+
_alg_autodiff(::OrdinaryDiffEqImplicitAlgorithm{CS, AD}) where {CS, AD} = Val{AD}()
8+
_alg_autodiff(alg::CompositeAlgorithm) = _alg_autodiff(alg.algs[end])
9+
function _alg_autodiff(::Union{OrdinaryDiffEqExponentialAlgorithm{CS, AD},
10+
OrdinaryDiffEqAdaptiveExponentialAlgorithm{CS, AD}
11+
}) where {
12+
CS, AD
13+
}
14+
Val{AD}()
15+
end
16+
17+
function alg_autodiff(alg)
18+
autodiff = _alg_autodiff(alg)
19+
if autodiff == Val(false)
20+
return AutoFiniteDiff()
21+
elseif autodiff == Val(true)
22+
return AutoForwardDiff()
23+
else
24+
return _unwrap_val(autodiff)
25+
end
26+
end
27+
28+
Base.@pure function determine_chunksize(u, alg::DiffEqBase.DEAlgorithm)
29+
determine_chunksize(u, get_chunksize(alg))
30+
end
31+
Base.@pure function determine_chunksize(u, CS)
32+
if CS != 0
33+
return CS
34+
else
35+
return ForwardDiff.pickchunksize(length(u))
36+
end
37+
end
38+
39+
function DiffEqBase.prepare_alg(
40+
alg::Union{
41+
OrdinaryDiffEqAdaptiveImplicitAlgorithm{0, AD,
42+
FDT},
43+
OrdinaryDiffEqImplicitAlgorithm{0, AD, FDT},
44+
DAEAlgorithm{0, AD, FDT},
45+
OrdinaryDiffEqExponentialAlgorithm{0, AD, FDT}},
46+
u0::AbstractArray{T},
47+
p, prob) where {AD, FDT, T}
48+
49+
# If not using autodiff or norecompile mode or very large bitsize (like a dual number u0 already)
50+
# don't use a large chunksize as it will either error or not be beneficial
51+
if !(alg_autodiff(alg) isa AutoForwardDiff) ||
52+
(isbitstype(T) && sizeof(T) > 24) ||
53+
(prob.f isa ODEFunction &&
54+
prob.f.f isa FunctionWrappersWrappers.FunctionWrappersWrapper)
55+
return remake(alg, chunk_size = Val{1}())
56+
end
57+
58+
L = StaticArrayInterface.known_length(typeof(u0))
59+
if L === nothing # dynamic sized
60+
# If chunksize is zero, pick chunksize right at the start of solve and
61+
# then do function barrier to infer the full solve
62+
x = if prob.f.colorvec === nothing
63+
length(u0)
64+
else
65+
maximum(prob.f.colorvec)
66+
end
67+
68+
cs = ForwardDiff.pickchunksize(x)
69+
return remake(alg, chunk_size = Val{cs}())
70+
else # statically sized
71+
cs = pick_static_chunksize(Val{L}())
72+
return remake(alg, chunk_size = cs)
73+
end
74+
end
75+
76+
@generated function pick_static_chunksize(::Val{chunksize}) where {chunksize}
77+
x = ForwardDiff.pickchunksize(chunksize)
78+
:(Val{$x}())
79+
end

src/derivative_utils.jl renamed to lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -626,6 +626,8 @@ function jacobian2W(mass_matrix, dtgamma::Number, J::AbstractMatrix,
626626
return W
627627
end
628628

629+
is_always_new(alg) = isdefined(alg, :always_new) ? alg.always_new : false
630+
629631
function calc_W!(W, integrator, nlsolver::Union{Nothing, AbstractNLSolver}, cache, dtgamma,
630632
repeat_step, W_transform = false, newJW = nothing)
631633
@unpack t, dt, uprev, u, f, p = integrator
@@ -920,7 +922,7 @@ function LinearSolve.init_cacheval(
920922
assumptions::OperatorAssumptions)
921923
end
922924

923-
for alg in InteractiveUtils.subtypes(OrdinaryDiffEq.LinearSolve.AbstractFactorization)
925+
for alg in InteractiveUtils.subtypes(LinearSolve.AbstractFactorization)
924926
@eval function LinearSolve.init_cacheval(alg::$alg, A::WOperator, b, u, Pl, Pr,
925927
maxiters::Int, abstol, reltol, verbose::Bool,
926928
assumptions::OperatorAssumptions)
@@ -929,3 +931,39 @@ for alg in InteractiveUtils.subtypes(OrdinaryDiffEq.LinearSolve.AbstractFactoriz
929931
assumptions::OperatorAssumptions)
930932
end
931933
end
934+
935+
function resize_J_W!(cache, integrator, i)
936+
(isdefined(cache, :J) && isdefined(cache, :W)) || return
937+
938+
@unpack f = integrator
939+
940+
if cache.W isa WOperator
941+
nf = nlsolve_f(f, integrator.alg)
942+
islin = f isa Union{ODEFunction, SplitFunction} && islinear(nf.f)
943+
if !islin
944+
if cache.J isa AbstractSciMLOperator
945+
resize!(cache.J, i)
946+
elseif f.jac_prototype !== nothing
947+
J = similar(f.jac_prototype, i, i)
948+
J = MatrixOperator(J; update_func! = f.jac)
949+
end
950+
if cache.W.jacvec isa AbstractSciMLOperator
951+
resize!(cache.W.jacvec, i)
952+
end
953+
cache.W = WOperator{DiffEqBase.isinplace(integrator.sol.prob)}(f.mass_matrix,
954+
integrator.dt,
955+
cache.J,
956+
integrator.u,
957+
cache.W.jacvec;
958+
transform = cache.W.transform)
959+
cache.J = cache.W.J
960+
end
961+
else
962+
if cache.J !== nothing
963+
cache.J = similar(cache.J, i, i)
964+
end
965+
cache.W = similar(cache.W, i, i)
966+
end
967+
968+
nothing
969+
end
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
issuccess_W(W::LinearAlgebra.Factorization) = LinearAlgebra.issuccess(W)
2+
issuccess_W(W::Number) = !iszero(W)
3+
issuccess_W(::Any) = true
4+
5+
function dolinsolve(integrator, linsolve; A = nothing, linu = nothing, b = nothing,
6+
du = nothing, u = nothing, p = nothing, t = nothing,
7+
weight = nothing, solverdata = nothing,
8+
reltol = integrator === nothing ? nothing : integrator.opts.reltol)
9+
A !== nothing && (linsolve.A = A)
10+
b !== nothing && (linsolve.b = b)
11+
linu !== nothing && (linsolve.u = linu)
12+
13+
Plprev = linsolve.Pl isa LinearSolve.ComposePreconditioner ? linsolve.Pl.outer :
14+
linsolve.Pl
15+
Prprev = linsolve.Pr isa LinearSolve.ComposePreconditioner ? linsolve.Pr.outer :
16+
linsolve.Pr
17+
18+
_alg = unwrap_alg(integrator, true)
19+
20+
_Pl, _Pr = _alg.precs(linsolve.A, du, u, p, t, A !== nothing, Plprev, Prprev,
21+
solverdata)
22+
if (_Pl !== nothing || _Pr !== nothing)
23+
__Pl = _Pl === nothing ? SciMLOperators.IdentityOperator(length(integrator.u)) : _Pl
24+
__Pr = _Pr === nothing ? SciMLOperators.IdentityOperator(length(integrator.u)) : _Pr
25+
linsolve.Pl = __Pl
26+
linsolve.Pr = __Pr
27+
end
28+
29+
linres = solve!(linsolve; reltol)
30+
31+
# TODO: this ignores the add of the `f` count for add_steps!
32+
if integrator isa SciMLBase.DEIntegrator && _alg.linsolve !== nothing &&
33+
!LinearSolve.needs_concrete_A(_alg.linsolve) &&
34+
linsolve.A isa WOperator && linsolve.A.J isa AbstractSciMLOperator
35+
if alg_autodiff(_alg) isa AutoForwardDiff
36+
integrator.stats.nf += linres.iters
37+
elseif alg_autodiff(_alg) isa AutoFiniteDiff
38+
integrator.stats.nf += 2 * linres.iters
39+
else
40+
error("$alg_autodiff not yet supported in dolinsolve function")
41+
end
42+
end
43+
44+
return linres
45+
end
46+
47+
function wrapprecs(_Pl::Nothing, _Pr::Nothing, weight, u)
48+
Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight)))
49+
Pr = Diagonal(_vec(weight))
50+
Pl, Pr
51+
end
52+
53+
function wrapprecs(_Pl, _Pr, weight, u)
54+
Pl = _Pl === nothing ? SciMLOperators.IdentityOperator(length(u)) : _Pl
55+
Pr = _Pr === nothing ? SciMLOperators.IdentityOperator(length(u)) : _Pr
56+
Pl, Pr
57+
end
58+
59+
Base.resize!(p::LinearSolve.LinearCache, i) = p

lib/OrdinaryDiffEqDifferentiation/test/runtests.jl

Whitespace-only changes.

lib/OrdinaryDiffEqExponentialRK/src/OrdinaryDiffEqExponentialRK.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@ module OrdinaryDiffEqExponentialRK
22

33
import OrdinaryDiffEq: alg_order, alg_adaptive_order, ismultistep, OrdinaryDiffEqExponentialAlgorithm,
44
_unwrap_val, OrdinaryDiffEqMutableCache, OrdinaryDiffEqConstantCache,
5-
build_jac_config, UJacobianWrapper, @cache, alg_cache, UDerivativeWrapper,
6-
initialize!, perform_step!, @unpack, unwrap_alg, calc_J, calc_J!,
5+
@cache, alg_cache,
6+
initialize!, perform_step!, @unpack, unwrap_alg,
77
OrdinaryDiffEqAdaptiveExponentialAlgorithm, CompositeAlgorithm,
88
ExponentialAlgorithm, fsal_typeof, isdtchangeable, calculate_residuals, calculate_residuals!
99
using RecursiveArrayTools
@@ -12,6 +12,7 @@ using LinearAlgebra: axpy!, mul!
1212
using DiffEqBase, SciMLBase
1313
using ExponentialUtilities
1414
import RecursiveArrayTools: recursivecopy!
15+
using OrdinaryDiffEq.OrdinaryDiffEqDifferentiation: build_jac_config, UJacobianWrapper, UDerivativeWrapper, calc_J, calc_J!
1516

1617
include("algorithms.jl")
1718
include("alg_utils.jl")

lib/OrdinaryDiffEqExtrapolation/src/OrdinaryDiffEqExtrapolation.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@ import OrdinaryDiffEq: alg_order, alg_maximum_order, get_current_adaptive_order,
1414
DEFAULT_PRECS,
1515
constvalue, PolyesterThreads, Sequential, BaseThreads,
1616
_digest_beta1_beta2, timedepentdtmin, _unwrap_val,
17-
TimeDerivativeWrapper, UDerivativeWrapper, calc_J, _reshape, _vec,
18-
WOperator, TimeGradientWrapper, UJacobianWrapper, build_grad_config,
19-
build_jac_config, calc_J!, jacobian2W!, dolinsolve
17+
_reshape, _vec
2018
using DiffEqBase, FastBroadcast, Polyester, MuladdMacro, RecursiveArrayTools, LinearSolve
19+
import OrdinaryDiffEq.OrdinaryDiffEqDifferentiation: TimeDerivativeWrapper, UDerivativeWrapper, calc_J, WOperator, TimeGradientWrapper, UJacobianWrapper, build_grad_config,
20+
build_jac_config, calc_J!, jacobian2W!, dolinsolve
2121

2222
include("algorithms.jl")
2323
include("alg_utils.jl")

0 commit comments

Comments
 (0)