Skip to content

Commit b37a5cc

Browse files
Merge pull request #2544 from AayushSabharwal/as/scc-init
feat: allow using `SCCNonlinearProblem` for initialization
2 parents 14e6cb6 + 074b1ca commit b37a5cc

File tree

5 files changed

+18
-12
lines changed

5 files changed

+18
-12
lines changed

.github/workflows/Downstream.yml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,10 @@ jobs:
3030
- {user: SciML, repo: SciMLSensitivity.jl, group: Core3}
3131
- {user: SciML, repo: SciMLSensitivity.jl, group: Core4}
3232
- {user: SciML, repo: SciMLSensitivity.jl, group: Core5}
33-
- {user: SciML, repo: ModelingToolkit.jl, group: All}
33+
- {user: SciML, repo: ModelingToolkit.jl, group: InterfaceI}
34+
- {user: SciML, repo: ModelingToolkit.jl, group: InterfaceII}
35+
- {user: SciML, repo: ModelingToolkit.jl, group: Initialization}
36+
- {user: SciML, repo: ModelingToolkit.jl, group: SymbolicIndexingInterface}
3437
- {user: SciML, repo: DiffEqDevTools.jl, group: Core}
3538
- {user: nathanaelbosch, repo: ProbNumDiffEq.jl, group: Downstream}
3639
- {user: SKopecz, repo: PositiveIntegrators.jl, group: Downstream}

lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ using FastBroadcast: @.., True, False
6262

6363
using SciMLBase: NoInit, CheckInit, OverrideInit, AbstractDEProblem, _unwrap_val
6464

65-
import SciMLBase: alg_order
65+
import SciMLBase: AbstractNonlinearProblem, alg_order
6666

6767
import DiffEqBase: calculate_residuals,
6868
calculate_residuals!, unwrap_cache,
@@ -76,7 +76,8 @@ import Accessors: @reset
7676

7777
using SciMLStructures: canonicalize, Tunable, isscimlstructure
7878

79-
using SymbolicIndexingInterface: parameter_values, is_variable, variable_index, symbolic_type, NotSymbolic
79+
using SymbolicIndexingInterface: state_values, parameter_values, is_variable, variable_index,
80+
symbolic_type, NotSymbolic
8081

8182
const CompiledFloats = Union{Float32, Float64}
8283
import Preferences

lib/OrdinaryDiffEqCore/src/initialize_dae.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ default_nlsolve(alg, isinplace, u, initprob, autodiff = false) = alg
101101

102102
## If the initialization is trivial just use nothing alg
103103
function default_nlsolve(
104-
::Nothing, isinplace::Val{true}, u::Nothing, ::NonlinearProblem, autodiff = false)
104+
::Nothing, isinplace::Val{true}, u::Nothing, ::AbstractNonlinearProblem, autodiff = false)
105105
nothing
106106
end
107107

@@ -111,7 +111,7 @@ function default_nlsolve(
111111
end
112112

113113
function default_nlsolve(
114-
::Nothing, isinplace::Val{false}, u::Nothing, ::NonlinearProblem, autodiff = false)
114+
::Nothing, isinplace::Val{false}, u::Nothing, ::AbstractNonlinearProblem, autodiff = false)
115115
nothing
116116
end
117117

@@ -122,7 +122,7 @@ function default_nlsolve(
122122
end
123123

124124
function OrdinaryDiffEqCore.default_nlsolve(
125-
::Nothing, isinplace, u, ::NonlinearProblem, autodiff = false)
125+
::Nothing, isinplace, u, ::AbstractNonlinearProblem, autodiff = false)
126126
error("This ODE requires a DAE initialization and thus a nonlinear solve but no nonlinear solve has been loaded. To solve this problem, do `using OrdinaryDiffEqNonlinearSolve` or pass a custom `nlsolve` choice into the `initializealg`.")
127127
end
128128

@@ -146,15 +146,16 @@ function _initialize_dae!(integrator, prob::AbstractDEProblem,
146146
# If it doesn't have autodiff, assume it comes from symbolic system like ModelingToolkit
147147
# Since then it's the case of not a DAE but has initializeprob
148148
# In which case, it should be differentiable
149-
isAD = if initializeprob.u0 === nothing
149+
iu0 = state_values(initializeprob)
150+
isAD = if iu0 === nothing
150151
AutoForwardDiff
151152
elseif has_autodiff(integrator.alg)
152153
alg_autodiff(integrator.alg) isa AutoForwardDiff
153154
else
154155
true
155156
end
156157

157-
nlsolve_alg = default_nlsolve(alg.nlsolve, isinplace, initializeprob.u0, initializeprob, isAD)
158+
nlsolve_alg = default_nlsolve(alg.nlsolve, isinplace, iu0, initializeprob, isAD)
158159

159160
u0, p, success = SciMLBase.get_initial_values(prob, integrator, prob.f, alg, isinplace; nlsolve_alg, abstol = integrator.opts.abstol, reltol = integrator.opts.reltol)
160161

lib/OrdinaryDiffEqNonlinearSolve/src/OrdinaryDiffEqNonlinearSolve.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ import SciMLBase
66
import SciMLBase: init, solve, solve!, remake
77
using SciMLBase: DAEFunction, DEIntegrator, NonlinearFunction, NonlinearProblem,
88
NonlinearLeastSquaresProblem, LinearProblem, ODEProblem, DAEProblem,
9-
update_coefficients!, get_tmp_cache, AbstractSciMLOperator, ReturnCode
9+
update_coefficients!, get_tmp_cache, AbstractSciMLOperator, ReturnCode,
10+
AbstractNonlinearProblem
1011
import DiffEqBase
1112
import PreallocationTools
1213
using SimpleNonlinearSolve: SimpleTrustRegion, SimpleGaussNewton

lib/OrdinaryDiffEqNonlinearSolve/src/initialize_dae.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
function default_nlsolve(
2-
::Nothing, isinplace::Val{true}, u, ::NonlinearProblem, autodiff = false)
2+
::Nothing, isinplace::Val{true}, u, ::AbstractNonlinearProblem, autodiff = false)
33
FastShortcutNonlinearPolyalg(;
44
autodiff = autodiff ? AutoForwardDiff() : AutoFiniteDiff())
55
end
@@ -8,7 +8,7 @@ function default_nlsolve(
88
FastShortcutNLLSPolyalg(; autodiff = autodiff ? AutoForwardDiff() : AutoFiniteDiff())
99
end
1010
function default_nlsolve(
11-
::Nothing, isinplace::Val{false}, u, ::NonlinearProblem, autodiff = false)
11+
::Nothing, isinplace::Val{false}, u, ::AbstractNonlinearProblem, autodiff = false)
1212
FastShortcutNonlinearPolyalg(;
1313
autodiff = autodiff ? AutoForwardDiff() : AutoFiniteDiff())
1414
end
@@ -17,7 +17,7 @@ function default_nlsolve(
1717
FastShortcutNLLSPolyalg(; autodiff = autodiff ? AutoForwardDiff() : AutoFiniteDiff())
1818
end
1919
function default_nlsolve(::Nothing, isinplace::Val{false}, u::StaticArray,
20-
::NonlinearProblem, autodiff = false)
20+
::AbstractNonlinearProblem, autodiff = false)
2121
SimpleTrustRegion(autodiff = autodiff ? AutoForwardDiff() : AutoFiniteDiff())
2222
end
2323
function default_nlsolve(::Nothing, isinplace::Val{false}, u::StaticArray,

0 commit comments

Comments
 (0)