Skip to content

Commit bbc413b

Browse files
Merge pull request #877 from AayushSabharwal/as/scc-nlprob-sii
feat: implement SII for `SCCNonlinearProblem`
2 parents e3846c8 + 069ffd2 commit bbc413b

File tree

3 files changed

+157
-3
lines changed

3 files changed

+157
-3
lines changed

src/problems/nonlinear_problems.jl

Lines changed: 59 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -462,7 +462,64 @@ Note that this example aliases the parameters together for a memory-reduced repr
462462
* `probs`: the collection of problems to solve
463463
* `explictfuns!`: the explicit functions for mutating the parameter set
464464
"""
465-
mutable struct SCCNonlinearProblem{P, E}
465+
mutable struct SCCNonlinearProblem{uType, iip, P, E, I, Par} <:
466+
AbstractNonlinearProblem{uType, iip}
466467
probs::P
467-
explictfuns!::E
468+
explicitfuns!::E
469+
full_index_provider::I
470+
parameter_object::Par
471+
parameters_alias::Bool
472+
473+
function SCCNonlinearProblem{P, E, I, Par}(
474+
probs::P, funs::E, indp::I, pobj::Par, alias::Bool) where {P, E, I, Par}
475+
u0 = mapreduce(state_values, vcat, probs)
476+
uType = typeof(u0)
477+
new{uType, false, P, E, I, Par}(probs, funs, indp, pobj, alias)
478+
end
479+
end
480+
481+
function SCCNonlinearProblem(probs, explicitfuns!, full_index_provider = nothing,
482+
parameter_object = nothing, parameters_alias = false)
483+
return SCCNonlinearProblem{typeof(probs), typeof(explicitfuns!),
484+
typeof(full_index_provider), typeof(parameter_object)}(
485+
probs, explicitfuns!, full_index_provider, parameter_object, parameters_alias)
486+
end
487+
488+
function Base.getproperty(prob::SCCNonlinearProblem, name::Symbol)
489+
if name == :explictfuns!
490+
return getfield(prob, :explicitfuns!)
491+
elseif name == :ps
492+
return ParameterIndexingProxy(prob)
493+
end
494+
return getfield(prob, name)
495+
end
496+
497+
function SymbolicIndexingInterface.symbolic_container(prob::SCCNonlinearProblem)
498+
prob.full_index_provider
499+
end
500+
function SymbolicIndexingInterface.parameter_values(prob::SCCNonlinearProblem)
501+
prob.parameter_object
502+
end
503+
function SymbolicIndexingInterface.state_values(prob::SCCNonlinearProblem)
504+
mapreduce(state_values, vcat, prob.probs)
505+
end
506+
507+
function SymbolicIndexingInterface.set_state!(prob::SCCNonlinearProblem, val, idx)
508+
for scc in prob.probs
509+
svals = state_values(scc)
510+
checkbounds(Bool, svals, idx) && return set_state!(scc, val, idx)
511+
idx -= length(svals)
512+
end
513+
throw(BoundsError(state_values(prob), idx))
514+
end
515+
516+
function SymbolicIndexingInterface.set_parameter!(prob::SCCNonlinearProblem, val, idx)
517+
if prob.parameter_object !== nothing
518+
set_parameter!(prob.parameter_object, val, idx)
519+
prob.parameters_alias && return
520+
end
521+
for scc in prob.probs
522+
is_parameter(scc, idx) || continue
523+
set_parameter!(scc, val, idx)
524+
end
468525
end

src/problems/problem_utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ function Base.show(io::IO, mime::MIME"text/plain", A::AbstractNonlinearProblem)
6060
summary(io, A)
6161
println(io)
6262
print(io, "u0: ")
63-
show(io, mime, A.u0)
63+
show(io, mime, state_values(A))
6464
end
6565

6666
function Base.show(io::IO, mime::MIME"text/plain", A::IntervalNonlinearProblem)

test/downstream/problem_interface.jl

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,3 +292,100 @@ prob = SteadyStateProblem(osys, u0, ps)
292292
getsym(prob, [:X, :X2])(prob) == [0.1, 0.2]
293293
@test getsym(prob, (X, X2))(prob) == getsym(prob, (osys.X, osys.X2))(prob) ==
294294
getsym(prob, (:X, :X2))(prob) == (0.1, 0.2)
295+
296+
@testset "SCCNonlinearProblem" begin
297+
# TODO: Rewrite this example when the MTK codegen is merged
298+
299+
function fullf!(du, u, p)
300+
du[1] = cos(u[2]) - u[1]
301+
du[2] = sin(u[1] + u[2]) + u[2]
302+
du[3] = 2u[4] + u[3] + p[1]
303+
du[4] = u[5]^2 + u[4]
304+
du[5] = u[3]^2 + u[5]
305+
du[6] = u[1] + u[2] + u[3] + u[4] + u[5] + 2.0u[6] + 2.5u[7] + 1.5u[8]
306+
du[7] = u[1] + u[2] + u[3] + 2.0u[4] + u[5] + 4.0u[6] - 1.5u[7] + 1.5u[8]
307+
du[8] = u[1] + 2.0u[2] + 3.0u[3] + 5.0u[4] + 6.0u[5] + u[6] - u[7] - u[8]
308+
end
309+
@variables u[1:8]=zeros(8) [irreducible = true]
310+
u2 = collect(u)
311+
@parameters p = 1.0
312+
eqs = Any[0 for _ in 1:8]
313+
fullf!(eqs, u, [p])
314+
@named model = NonlinearSystem(0 .~ eqs, [u...], [p])
315+
model = complete(model; split = false)
316+
317+
cache = zeros(4)
318+
cache[1] = 1.0
319+
320+
function f1!(du, u, p)
321+
du[1] = cos(u[2]) - u[1]
322+
du[2] = sin(u[1] + u[2]) + u[2]
323+
end
324+
explicitfun1(cache, sols) = nothing
325+
326+
f1!(eqs, u2[1:2], [p])
327+
@named subsys1 = NonlinearSystem(0 .~ eqs[1:2], [u2[1:2]...], [p])
328+
subsys1 = complete(subsys1; split = false)
329+
prob1 = NonlinearProblem(
330+
NonlinearFunction{true, SciMLBase.NoSpecialize}(f1!; sys = subsys1),
331+
zeros(2), copy(cache))
332+
333+
function f2!(du, u, p)
334+
du[1] = 2u[2] + u[1] + p[1]
335+
du[2] = u[3]^2 + u[2]
336+
du[3] = u[1]^2 + u[3]
337+
end
338+
explicitfun2(cache, sols) = nothing
339+
340+
f2!(eqs, u2[3:5], [p])
341+
@named subsys2 = NonlinearSystem(0 .~ eqs[1:3], [u2[3:5]...], [p])
342+
subsys2 = complete(subsys2; split = false)
343+
prob2 = NonlinearProblem(
344+
NonlinearFunction{true, SciMLBase.NoSpecialize}(f2!; sys = subsys2),
345+
zeros(3), copy(cache))
346+
347+
function f3!(du, u, p)
348+
du[1] = p[2] + 2.0u[1] + 2.5u[2] + 1.5u[3]
349+
du[2] = p[3] + 4.0u[1] - 1.5u[2] + 1.5u[3]
350+
du[3] = p[4] + +u[1] - u[2] - u[3]
351+
end
352+
function explicitfun3(cache, sols)
353+
cache[2] = sols[1][1] + sols[1][2] + sols[2][1] + sols[2][2] + sols[2][3]
354+
cache[3] = sols[1][1] + sols[1][2] + sols[2][1] + 2.0sols[2][2] + sols[2][3]
355+
cache[4] = sols[1][1] + 2.0sols[1][2] + 3.0sols[2][1] + 5.0sols[2][2] +
356+
6.0sols[2][3]
357+
end
358+
359+
@parameters tmpvar[1:3]
360+
f3!(eqs, u2[6:8], [p, tmpvar...])
361+
@named subsys3 = NonlinearSystem(0 .~ eqs[1:3], [u2[6:8]...], [p, tmpvar...])
362+
subsys3 = complete(subsys3; split = false)
363+
prob3 = NonlinearProblem(
364+
NonlinearFunction{true, SciMLBase.NoSpecialize}(f3!; sys = subsys3),
365+
zeros(3), copy(cache))
366+
367+
prob = NonlinearProblem(model, [])
368+
sccprob = SciMLBase.SCCNonlinearProblem([prob1, prob2, prob3],
369+
SciMLBase.Void{Any}.([explicitfun1, explicitfun2, explicitfun3]),
370+
model, copy(cache))
371+
372+
for sym in [u, u..., u[2] + u[3], p * u[1] + u[2]]
373+
@test prob[sym] sccprob[sym]
374+
end
375+
376+
for sym in [p, 2p + 1]
377+
@test prob.ps[sym] sccprob.ps[sym]
378+
end
379+
380+
for (i, sym) in enumerate([u[1], u[3], u[6]])
381+
sccprob[sym] = 0.5i
382+
@test sccprob[sym] 0.5i
383+
@test sccprob.probs[i].u0[1] 0.5i
384+
end
385+
sccprob.ps[p] = 2.5
386+
@test sccprob.ps[p] 2.5
387+
@test sccprob.parameter_object[1] 2.5
388+
for scc in sccprob.probs
389+
@test parameter_values(scc)[1] 2.5
390+
end
391+
end

0 commit comments

Comments
 (0)