Skip to content

Commit 1650ddf

Browse files
Merge branch 'master' into compathelper/new_version/2024-11-22-03-28-35-271-01723128443
2 parents b0b00af + c39a4d4 commit 1650ddf

File tree

7 files changed

+184
-12
lines changed

7 files changed

+184
-12
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SciMLBase"
22
uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
33
authors = ["Chris Rackauckas <accounts@chrisrackauckas.com> and contributors"]
4-
version = "2.63.1"
4+
version = "2.64.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/problems/nonlinear_problems.jl

Lines changed: 59 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -462,7 +462,64 @@ Note that this example aliases the parameters together for a memory-reduced repr
462462
* `probs`: the collection of problems to solve
463463
* `explictfuns!`: the explicit functions for mutating the parameter set
464464
"""
465-
mutable struct SCCNonlinearProblem{P, E}
465+
mutable struct SCCNonlinearProblem{uType, iip, P, E, I, Par} <:
466+
AbstractNonlinearProblem{uType, iip}
466467
probs::P
467-
explictfuns!::E
468+
explicitfuns!::E
469+
full_index_provider::I
470+
parameter_object::Par
471+
parameters_alias::Bool
472+
473+
function SCCNonlinearProblem{P, E, I, Par}(
474+
probs::P, funs::E, indp::I, pobj::Par, alias::Bool) where {P, E, I, Par}
475+
u0 = mapreduce(state_values, vcat, probs)
476+
uType = typeof(u0)
477+
new{uType, false, P, E, I, Par}(probs, funs, indp, pobj, alias)
478+
end
479+
end
480+
481+
function SCCNonlinearProblem(probs, explicitfuns!, full_index_provider = nothing,
482+
parameter_object = nothing, parameters_alias = false)
483+
return SCCNonlinearProblem{typeof(probs), typeof(explicitfuns!),
484+
typeof(full_index_provider), typeof(parameter_object)}(
485+
probs, explicitfuns!, full_index_provider, parameter_object, parameters_alias)
486+
end
487+
488+
function Base.getproperty(prob::SCCNonlinearProblem, name::Symbol)
489+
if name == :explictfuns!
490+
return getfield(prob, :explicitfuns!)
491+
elseif name == :ps
492+
return ParameterIndexingProxy(prob)
493+
end
494+
return getfield(prob, name)
495+
end
496+
497+
function SymbolicIndexingInterface.symbolic_container(prob::SCCNonlinearProblem)
498+
prob.full_index_provider
499+
end
500+
function SymbolicIndexingInterface.parameter_values(prob::SCCNonlinearProblem)
501+
prob.parameter_object
502+
end
503+
function SymbolicIndexingInterface.state_values(prob::SCCNonlinearProblem)
504+
mapreduce(state_values, vcat, prob.probs)
505+
end
506+
507+
function SymbolicIndexingInterface.set_state!(prob::SCCNonlinearProblem, val, idx)
508+
for scc in prob.probs
509+
svals = state_values(scc)
510+
checkbounds(Bool, svals, idx) && return set_state!(scc, val, idx)
511+
idx -= length(svals)
512+
end
513+
throw(BoundsError(state_values(prob), idx))
514+
end
515+
516+
function SymbolicIndexingInterface.set_parameter!(prob::SCCNonlinearProblem, val, idx)
517+
if prob.parameter_object !== nothing
518+
set_parameter!(prob.parameter_object, val, idx)
519+
prob.parameters_alias && return
520+
end
521+
for scc in prob.probs
522+
is_parameter(scc, idx) || continue
523+
set_parameter!(scc, val, idx)
524+
end
468525
end

src/problems/problem_utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ function Base.show(io::IO, mime::MIME"text/plain", A::AbstractNonlinearProblem)
6060
summary(io, A)
6161
println(io)
6262
print(io, "u0: ")
63-
show(io, mime, A.u0)
63+
show(io, mime, state_values(A))
6464
end
6565

6666
function Base.show(io::IO, mime::MIME"text/plain", A::IntervalNonlinearProblem)

src/remake.jl

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,8 @@ function remake(prob::ODEProblem; f = missing,
125125

126126
if f === missing
127127
if build_initializeprob
128-
initialization_data = remake_initialization_data(
129-
prob.f.sys, prob.f, u0, tspan[1], p)
128+
initialization_data = remake_initialization_data_compat_wrapper(
129+
prob.f.sys, prob.f, u0, tspan[1], p, newu0, newp)
130130
else
131131
initialization_data = nothing
132132
end
@@ -203,16 +203,32 @@ function remake_initializeprob(sys, scimlfn, u0, t0, p)
203203
end
204204

205205
"""
206-
remake_initialization_data(sys, scimlfn, u0, t0, p)
206+
$(TYPEDSIGNATURES)
207+
208+
Wrapper around `remake_initialization_data` for backward compatibility when `newu0` and
209+
`newp` were not arguments.
210+
"""
211+
function remake_initialization_data_compat_wrapper(sys, scimlfn, u0, t0, p, newu0, newp)
212+
if hasmethod(remake_initialization_data,
213+
Tuple{typeof(sys), typeof(scimlfn), typeof(u0), typeof(t0), typeof(p)})
214+
remake_initialization_data(sys, scimlfn, u0, t0, p)
215+
else
216+
remake_initialization_data(sys, scimlfn, u0, t0, p, newu0, newp)
217+
end
218+
end
219+
220+
"""
221+
remake_initialization_data(sys, scimlfn, u0, t0, p, newu0, newp)
207222
208223
Re-create the initialization data present in the function `scimlfn`, using the
209-
associated system `sys` and the user provided new values of `u0`, initial time `t0` and
210-
`p`. By default, this calls `remake_initializeprob` for backward compatibility and
211-
attempts to construct an `OverrideInitData` from the result.
224+
associated system `sys`, the user provided new values of `u0`, initial time `t0`,
225+
user-provided `p`, new u0 vector `newu0` and new parameter object `newp`. By default,
226+
this calls `remake_initializeprob` for backward compatibility and attempts to construct
227+
an `OverrideInitData` from the result.
212228
213229
Note that `u0` or `p` may be `missing` if the user does not provide a value for them.
214230
"""
215-
function remake_initialization_data(sys, scimlfn, u0, t0, p)
231+
function remake_initialization_data(sys, scimlfn, u0, t0, p, newu0, newp)
216232
return reconstruct_initialization_data(
217233
nothing, remake_initializeprob(sys, scimlfn, u0, t0, p)...)
218234
end

test/downstream/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,9 @@ ModelingToolkitStandardLibrary = "2.7"
3838
NonlinearSolve = "2, 3, 4"
3939
Optimization = "4"
4040
OptimizationOptimJL = "0.4"
41+
OptimizationMOI = "0.5"
4142
OrdinaryDiffEq = "6.33"
43+
PartialFunctions = "1"
4244
Plots = "1.40"
4345
RecursiveArrayTools = "3"
4446
SciMLBase = "2"

test/downstream/adjoints.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ gs_ts, = Zygote.gradient(sol) do sol
6868
sum(sum.(sol[[lorenz1.x, lorenz2.x], :]))
6969
end
7070

71-
@test_broken all(map(x -> x == true_grad_vecsym, gs_ts))
71+
@test all(map(x -> x == true_grad_vecsym, gs_ts))
7272

7373
# BatchedInterface AD
7474
@variables x(t)=1.0 y(t)=1.0 z(t)=1.0 w(t)=1.0

test/downstream/problem_interface.jl

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,3 +292,100 @@ prob = SteadyStateProblem(osys, u0, ps)
292292
getsym(prob, [:X, :X2])(prob) == [0.1, 0.2]
293293
@test getsym(prob, (X, X2))(prob) == getsym(prob, (osys.X, osys.X2))(prob) ==
294294
getsym(prob, (:X, :X2))(prob) == (0.1, 0.2)
295+
296+
@testset "SCCNonlinearProblem" begin
297+
# TODO: Rewrite this example when the MTK codegen is merged
298+
299+
function fullf!(du, u, p)
300+
du[1] = cos(u[2]) - u[1]
301+
du[2] = sin(u[1] + u[2]) + u[2]
302+
du[3] = 2u[4] + u[3] + p[1]
303+
du[4] = u[5]^2 + u[4]
304+
du[5] = u[3]^2 + u[5]
305+
du[6] = u[1] + u[2] + u[3] + u[4] + u[5] + 2.0u[6] + 2.5u[7] + 1.5u[8]
306+
du[7] = u[1] + u[2] + u[3] + 2.0u[4] + u[5] + 4.0u[6] - 1.5u[7] + 1.5u[8]
307+
du[8] = u[1] + 2.0u[2] + 3.0u[3] + 5.0u[4] + 6.0u[5] + u[6] - u[7] - u[8]
308+
end
309+
@variables u[1:8]=zeros(8) [irreducible = true]
310+
u2 = collect(u)
311+
@parameters p = 1.0
312+
eqs = Any[0 for _ in 1:8]
313+
fullf!(eqs, u, [p])
314+
@named model = NonlinearSystem(0 .~ eqs, [u...], [p])
315+
model = complete(model; split = false)
316+
317+
cache = zeros(4)
318+
cache[1] = 1.0
319+
320+
function f1!(du, u, p)
321+
du[1] = cos(u[2]) - u[1]
322+
du[2] = sin(u[1] + u[2]) + u[2]
323+
end
324+
explicitfun1(cache, sols) = nothing
325+
326+
f1!(eqs, u2[1:2], [p])
327+
@named subsys1 = NonlinearSystem(0 .~ eqs[1:2], [u2[1:2]...], [p])
328+
subsys1 = complete(subsys1; split = false)
329+
prob1 = NonlinearProblem(
330+
NonlinearFunction{true, SciMLBase.NoSpecialize}(f1!; sys = subsys1),
331+
zeros(2), copy(cache))
332+
333+
function f2!(du, u, p)
334+
du[1] = 2u[2] + u[1] + p[1]
335+
du[2] = u[3]^2 + u[2]
336+
du[3] = u[1]^2 + u[3]
337+
end
338+
explicitfun2(cache, sols) = nothing
339+
340+
f2!(eqs, u2[3:5], [p])
341+
@named subsys2 = NonlinearSystem(0 .~ eqs[1:3], [u2[3:5]...], [p])
342+
subsys2 = complete(subsys2; split = false)
343+
prob2 = NonlinearProblem(
344+
NonlinearFunction{true, SciMLBase.NoSpecialize}(f2!; sys = subsys2),
345+
zeros(3), copy(cache))
346+
347+
function f3!(du, u, p)
348+
du[1] = p[2] + 2.0u[1] + 2.5u[2] + 1.5u[3]
349+
du[2] = p[3] + 4.0u[1] - 1.5u[2] + 1.5u[3]
350+
du[3] = p[4] + +u[1] - u[2] - u[3]
351+
end
352+
function explicitfun3(cache, sols)
353+
cache[2] = sols[1][1] + sols[1][2] + sols[2][1] + sols[2][2] + sols[2][3]
354+
cache[3] = sols[1][1] + sols[1][2] + sols[2][1] + 2.0sols[2][2] + sols[2][3]
355+
cache[4] = sols[1][1] + 2.0sols[1][2] + 3.0sols[2][1] + 5.0sols[2][2] +
356+
6.0sols[2][3]
357+
end
358+
359+
@parameters tmpvar[1:3]
360+
f3!(eqs, u2[6:8], [p, tmpvar...])
361+
@named subsys3 = NonlinearSystem(0 .~ eqs[1:3], [u2[6:8]...], [p, tmpvar...])
362+
subsys3 = complete(subsys3; split = false)
363+
prob3 = NonlinearProblem(
364+
NonlinearFunction{true, SciMLBase.NoSpecialize}(f3!; sys = subsys3),
365+
zeros(3), copy(cache))
366+
367+
prob = NonlinearProblem(model, [])
368+
sccprob = SciMLBase.SCCNonlinearProblem([prob1, prob2, prob3],
369+
SciMLBase.Void{Any}.([explicitfun1, explicitfun2, explicitfun3]),
370+
model, copy(cache))
371+
372+
for sym in [u, u..., u[2] + u[3], p * u[1] + u[2]]
373+
@test prob[sym] sccprob[sym]
374+
end
375+
376+
for sym in [p, 2p + 1]
377+
@test prob.ps[sym] sccprob.ps[sym]
378+
end
379+
380+
for (i, sym) in enumerate([u[1], u[3], u[6]])
381+
sccprob[sym] = 0.5i
382+
@test sccprob[sym] 0.5i
383+
@test sccprob.probs[i].u0[1] 0.5i
384+
end
385+
sccprob.ps[p] = 2.5
386+
@test sccprob.ps[p] 2.5
387+
@test sccprob.parameter_object[1] 2.5
388+
for scc in sccprob.probs
389+
@test parameter_values(scc)[1] 2.5
390+
end
391+
end

0 commit comments

Comments
 (0)