Skip to content

Commit 005a470

Browse files
fix: improve type-stability of SCCNonlinearProblem
1 parent 0c9593b commit 005a470

File tree

3 files changed

+58
-25
lines changed

3 files changed

+58
-25
lines changed

src/problems/nonlinear_problems.jl

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -484,33 +484,37 @@ Note that this example aliases the parameters together for a memory-reduced repr
484484
* `probs`: the collection of problems to solve
485485
* `explictfuns!`: the explicit functions for mutating the parameter set
486486
"""
487-
mutable struct SCCNonlinearProblem{uType, iip, P, E, F <: NonlinearFunction{iip}, Par} <:
487+
mutable struct SCCNonlinearProblem{uType, iip, P, E, F <: NonlinearFunction{iip},
488+
Par, Palias <: Union{Val{true}, Val{false}}} <:
488489
AbstractNonlinearProblem{uType, iip}
489490
probs::P
490491
explicitfuns!::E
491492
# NonlinearFunction with `f = Returns(nothing)`
492493
f::F
493494
p::Par
494-
parameters_alias::Bool
495+
parameters_alias::Palias
495496

496497
function SCCNonlinearProblem{P, E, F, Par}(probs::P, funs::E, f::F, pobj::Par,
497-
alias::Bool) where {P, E, F <: NonlinearFunction, Par}
498+
alias::Palias) where {P, E, F <: NonlinearFunction, Par, Palias}
498499
init = state_values(first(probs))
499500
if ArrayInterface.ismutable(init)
500-
init = similar(init)
501+
init = similar(init, 0)
501502
else
502503
init = StaticArraysCore.similar_type(init, StaticArraysCore.Size(0))()
503504
end
504505
u0 = mapreduce(
505506
state_values, vcat, probs; init = init)
506507
uType = typeof(u0)
507-
new{uType, false, P, E, F, Par}(probs, funs, f, pobj, alias)
508+
new{uType, false, P, E, F, Par, Palias}(probs, funs, f, pobj, alias)
508509
end
509510
end
510511

511512
function SCCNonlinearProblem(probs, explicitfuns!, parameter_object = nothing,
512-
parameters_alias = false; kwargs...)
513+
parameters_alias::Union{Bool, Val{true}, Val{false}} = Val(false); kwargs...)
513514
f = NonlinearFunction{false}(Returns(nothing); kwargs...)
515+
if parameters_alias isa Bool
516+
parameters_alias = Val(parameters_alias)
517+
end
514518
return SCCNonlinearProblem{typeof(probs), typeof(explicitfuns!),
515519
typeof(f), typeof(parameter_object)}(
516520
probs, explicitfuns!, f, parameter_object, parameters_alias)
@@ -534,7 +538,7 @@ end
534538
function SymbolicIndexingInterface.state_values(prob::SCCNonlinearProblem)
535539
init = state_values(first(prob.probs))
536540
if ArrayInterface.ismutable(init)
537-
init = similar(init)
541+
init = similar(init, 0)
538542
else
539543
init = StaticArraysCore.similar_type(init, StaticArraysCore.Size(0))()
540544
end

src/remake.jl

Lines changed: 44 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -803,6 +803,49 @@ function remake(prob::NonlinearLeastSquaresProblem; f = missing, u0 = missing, p
803803
return prob
804804
end
805805

806+
function scc_update_subproblems(probs::Vector, newu0, newp, parameters_alias)
807+
offset = Ref(0)
808+
return map(probs) do subprob
809+
# N should be inferred if `prob` is type-stable and `subprob.u0 isa StaticArray`
810+
N = length(state_values(subprob))
811+
if ArrayInterface.ismutable(newu0)
812+
_u0 = newu0[(offset[] + 1):(offset[] + N)]
813+
else
814+
_u0 = StaticArraysCore.similar_type(
815+
newu0, StaticArraysCore.Size(N))(newu0[(offset[] + 1):(offset[] + N)])
816+
end
817+
subprob = if parameters_alias === Val(true)
818+
remake(subprob; u0 = _u0, p = newp)
819+
else
820+
remake(subprob; u0 = _u0)
821+
end
822+
offset[] += length(state_values(subprob))
823+
return subprob
824+
end
825+
end
826+
827+
function scc_update_subproblems(probs::Tuple, newu0, newp, parameters_alias)
828+
offset = Ref(0)
829+
return ntuple(Val(length(probs))) do i
830+
subprob = probs[i]
831+
# N should be inferred if `prob` is type-stable and `subprob.u0 isa StaticArray`
832+
N = length(state_values(subprob))
833+
if ArrayInterface.ismutable(newu0)
834+
_u0 = newu0[(offset[] + 1):(offset[] + N)]
835+
else
836+
_u0 = StaticArraysCore.similar_type(
837+
newu0, StaticArraysCore.Size(N))(newu0[(offset[] + 1):(offset[] + N)])
838+
end
839+
subprob = if parameters_alias === Val(true)
840+
remake(subprob; u0 = _u0, p = newp)
841+
else
842+
remake(subprob; u0 = _u0)
843+
end
844+
offset[] += N
845+
return subprob
846+
end
847+
end
848+
806849
"""
807850
remake(prob::SCCNonlinearProblem; u0 = missing, p = missing, probs = missing,
808851
parameters_alias = prob.parameters_alias, sys = missing, explicitfuns! = missing)
@@ -831,23 +874,8 @@ function remake(prob::SCCNonlinearProblem; u0 = missing, p = missing, probs = mi
831874
if sys === missing
832875
sys = prob.f.sys
833876
end
834-
offset = 0
835877
if u0 !== missing || p !== missing && parameters_alias
836-
probs = map(probs) do subprob
837-
_u0 = newu0[(offset + 1):(offset + length(state_values(subprob)))]
838-
if !ArrayInterface.ismutable(newu0)
839-
_u0 = StaticArraysCore.similar_type(
840-
newu0, StaticArraysCore.Size(length(_u0)))(_u0)
841-
end
842-
subprob = if parameters_alias
843-
remake(subprob; u0 = _u0, p = newp)
844-
else
845-
remake(subprob;
846-
u0 = _u0)
847-
end
848-
offset += length(state_values(subprob))
849-
return subprob
850-
end
878+
probs = scc_update_subproblems(probs, newu0, newp, parameters_alias)
851879
end
852880
f = coalesce(f, prob.f)
853881
f = remake(f; sys)

test/problem_building_test.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,14 +130,15 @@ end
130130
prob1 = NonlinearProblem(f1, SA[1.0], p)
131131
prob2 = NonlinearProblem(f2, SA[0.0], p)
132132
prob3 = NonlinearProblem(f3, SA[1.0], p)
133-
sccprob = SCCNonlinearProblem([prob1, prob2, prob3], [explicit1, explicit2, explicit3], p, true)
133+
sccprob = SCCNonlinearProblem(
134+
[prob1, prob2, prob3], [explicit1, explicit2, explicit3], p, true)
134135

135136
@test !SciMLBase.isinplace(sccprob)
136137
@test sccprob isa SCCNonlinearProblem{SVector{3, Float64}}
137138
@test state_values(sccprob) isa SVector{3, Float64}
138139
@test sccprob.p === prob1.p === prob2.p === prob3.p
139140

140-
sccprob2 = remake(sccprob; u0 = SA[2.0, 1.0, 2.0])
141+
sccprob2 = @inferred remake(sccprob; u0 = SA[2.0, 1.0, 2.0])
141142
@test !SciMLBase.isinplace(sccprob2)
142143
@test sccprob2 isa SCCNonlinearProblem{SVector{3, Float64}}
143144
@test state_values(sccprob2) isa SVector{3, Float64}

0 commit comments

Comments
 (0)