Skip to content

Commit 7df67f7

Browse files
fix: retain type of u0 in SCCNonlinearProblem impl for state_values and remake
1 parent eccc470 commit 7df67f7

File tree

4 files changed

+67
-8
lines changed

4 files changed

+67
-8
lines changed

src/SciMLBase.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import ADTypes: ADTypes, AbstractADType
2525
import Accessors: @set, @reset, @delete, @insert
2626
using Moshi.Data: @data
2727
using Moshi.Match: @match
28+
import StaticArraysCore
2829

2930
using Reexport
3031
using SciMLOperators

src/problems/nonlinear_problems.jl

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -495,8 +495,14 @@ mutable struct SCCNonlinearProblem{uType, iip, P, E, F <: NonlinearFunction{iip}
495495

496496
function SCCNonlinearProblem{P, E, F, Par}(probs::P, funs::E, f::F, pobj::Par,
497497
alias::Bool) where {P, E, F <: NonlinearFunction, Par}
498+
init = state_values(first(probs))
499+
if ArrayInterface.ismutable(init)
500+
init = similar(init)
501+
else
502+
init = StaticArraysCore.similar_type(init, StaticArraysCore.Size(0))()
503+
end
498504
u0 = mapreduce(
499-
state_values, vcat, probs; init = similar(state_values(first(probs)), 0))
505+
state_values, vcat, probs; init = init)
500506
uType = typeof(u0)
501507
new{uType, false, P, E, F, Par}(probs, funs, f, pobj, alias)
502508
end
@@ -526,8 +532,13 @@ function SymbolicIndexingInterface.parameter_values(prob::SCCNonlinearProblem)
526532
prob.p
527533
end
528534
function SymbolicIndexingInterface.state_values(prob::SCCNonlinearProblem)
529-
mapreduce(
530-
state_values, vcat, prob.probs; init = similar(state_values(first(prob.probs)), 0))
535+
init = state_values(first(prob.probs))
536+
if ArrayInterface.ismutable(init)
537+
init = similar(init)
538+
else
539+
init = StaticArraysCore.similar_type(init, StaticArraysCore.Size(0))()
540+
end
541+
mapreduce(state_values, vcat, prob.probs; init)
531542
end
532543

533544
function SymbolicIndexingInterface.set_state!(prob::SCCNonlinearProblem, val, idx)

src/remake.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -834,13 +834,16 @@ function remake(prob::SCCNonlinearProblem; u0 = missing, p = missing, probs = mi
834834
offset = 0
835835
if u0 !== missing || p !== missing && parameters_alias
836836
probs = map(probs) do subprob
837+
_u0 = newu0[(offset + 1):(offset + length(state_values(subprob)))]
838+
if !ArrayInterface.ismutable(newu0)
839+
_u0 = StaticArraysCore.similar_type(
840+
newu0, StaticArraysCore.Size(length(_u0)))(_u0)
841+
end
837842
subprob = if parameters_alias
838-
remake(subprob;
839-
u0 = newu0[(offset + 1):(offset + length(state_values(subprob)))],
840-
p = newp)
843+
remake(subprob; u0 = _u0, p = newp)
841844
else
842845
remake(subprob;
843-
u0 = newu0[(offset + 1):(offset + length(state_values(subprob)))])
846+
u0 = _u0)
844847
end
845848
offset += length(state_values(subprob))
846849
return subprob

test/problem_building_test.jl

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using Test, SciMLBase, SymbolicIndexingInterface, Accessors
1+
using Test, SciMLBase, SymbolicIndexingInterface, Accessors, StaticArrays
22

33
function simplependulum!(du, u, p, t)
44
θ = u[1]
@@ -98,3 +98,47 @@ let
9898
@test dprob8.u0 === u₀
9999
@test dprob8.tspan === tspan
100100
end
101+
102+
@testset "SCCNonlinearProblem with static arrays" begin
103+
function f1(u, p)
104+
y = u[1]
105+
x = p[1]
106+
return SA[1 - y^2 - x^2]
107+
end
108+
109+
function f2(u, p)
110+
yt = u[1]
111+
x, xt, y = p
112+
return SA[-2y * yt - 2x * xt]
113+
end
114+
115+
function f3(u, p)
116+
lam = u[1]
117+
x, xt, y, yt = p
118+
return SA[-2xt^2 - 2yt^2 - 2y * (-1 + y * lam) - 2x^2 * lam]
119+
end
120+
121+
explicit1 = Returns(nothing)
122+
function explicit2(p, sols)
123+
p[3] = sols[1].u[1]
124+
end
125+
function explicit3(p, sols)
126+
p[4] = sols[2].u[1]
127+
end
128+
129+
p = [1.0, 0.0, NaN, NaN]
130+
prob1 = NonlinearProblem(f1, SA[1.0], p)
131+
prob2 = NonlinearProblem(f2, SA[0.0], p)
132+
prob3 = NonlinearProblem(f3, SA[1.0], p)
133+
sccprob = SCCNonlinearProblem([prob1, prob2, prob3], [explicit1, explicit2, explicit3], p, true)
134+
135+
@test !SciMLBase.isinplace(sccprob)
136+
@test sccprob isa SCCNonlinearProblem{SVector{3, Float64}}
137+
@test state_values(sccprob) isa SVector{3, Float64}
138+
@test sccprob.p === prob1.p === prob2.p === prob3.p
139+
140+
sccprob2 = remake(sccprob; u0 = SA[2.0, 1.0, 2.0])
141+
@test !SciMLBase.isinplace(sccprob2)
142+
@test sccprob2 isa SCCNonlinearProblem{SVector{3, Float64}}
143+
@test state_values(sccprob2) isa SVector{3, Float64}
144+
end

0 commit comments

Comments
 (0)