Skip to content

Commit b635875

Browse files
fix tests
1 parent 541db50 commit b635875

File tree

2 files changed

+38
-36
lines changed

2 files changed

+38
-36
lines changed

lib/SCCNonlinearSolve/src/SCCNonlinearSolve.jl

Lines changed: 32 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ module SCCNonlinearSolve
33
import SciMLBase
44
import CommonSolve
55
import SymbolicIndexingInterface
6+
import SciMLBase: NonlinearProblem, NonlinearLeastSquaresProblem, LinearProblem
67

78
"""
89
SCCAlg(; nlalg = nothing, linalg = nothing)
@@ -25,41 +26,42 @@ function CommonSolve.solve(prob::SciMLBase.SCCNonlinearProblem; kwargs...)
2526
CommonSolve.solve(prob, SCCAlg(nothing, nothing); kwargs...)
2627
end
2728

29+
probvec(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}) = prob.u0
30+
probvec(prob::LinearProblem) = prob.b
31+
32+
iteratively_build_sols(alg, sols; kwargs...) = sols
33+
34+
function iteratively_build_sols(alg, sols, (prob, explicitfun), args...; kwargs...)
35+
explicitfun(
36+
SymbolicIndexingInterface.parameter_values(prob), sols)
37+
38+
_sol = if prob isa SciMLBase.LinearProblem
39+
sol = SciMLBase.solve(prob, alg.linalg; kwargs...)
40+
SciMLBase.build_linear_solution(
41+
alg.linalg, sol.u, nothing, nothing, retcode = sol.retcode)
42+
else
43+
sol = SciMLBase.solve(prob, alg.nlalg; kwargs...)
44+
SciMLBase.build_solution(
45+
prob, nothing, sol.u, sol.resid, retcode = sol.retcode)
46+
end
47+
48+
iteratively_build_sols(alg, (sols..., _sol), args...)
49+
end
50+
2851
function CommonSolve.solve(prob::SciMLBase.SCCNonlinearProblem, alg::SCCAlg; kwargs...)
2952
numscc = length(prob.probs)
30-
sols = [SciMLBase.build_solution(
31-
prob, nothing, prob.u0, convert(eltype(prob.u0), NaN) * prob.u0)
32-
for prob in prob.probs]
33-
u = reduce(vcat, [prob.u0 for prob in prob.probs])
34-
resid = copy(u)
35-
36-
lasti = 1
37-
for i in 1:numscc
38-
prob.explictfuns![i](
39-
SymbolicIndexingInterface.parameter_values(prob.probs[i]), sols)
40-
41-
if prob.probs[i] isa SciMLBase.LinearProblem
42-
sol = SciMLBase.solve(prob.probs[i], alg.linalg; kwargs...)
43-
_sol = SciMLBase.build_solution(
44-
prob.probs[i], nothing, sol.u, zero(sol.u), retcode = sol.retcode)
45-
else
46-
sol = SciMLBase.solve(prob.probs[i], alg.nlalg; kwargs...)
47-
_sol = SciMLBase.build_solution(
48-
prob.probs[i], nothing, sol.u, sol.resid, retcode = sol.retcode)
49-
end
50-
51-
sols[i] = _sol
52-
lasti = i
53-
if !SciMLBase.successful_retcode(_sol)
54-
break
55-
end
56-
end
53+
sols = iteratively_build_sols(alg, (), zip(prob.probs, prob.explicitfuns!)...; kwargs...)
5754

5855
# TODO: fix allocations with a lazy concatenation
59-
u .= reduce(vcat, sols)
60-
resid .= reduce(vcat, getproperty.(sols, :resid))
56+
u = reduce(vcat, sols)
57+
resid = reduce(vcat, getproperty.(sols, :resid))
6158

62-
retcode = sols[lasti].retcode
59+
retcode = if !all(SciMLBase.successful_retcode, sols)
60+
idx = findfirst(!SciMLBase.successful_retcode, sols)
61+
sols[idx].retcode
62+
else
63+
SciMLBase.ReturnCode.Success
64+
end
6365

6466
SciMLBase.build_solution(prob, alg, u, resid; retcode, original = sols)
6567
end

lib/SCCNonlinearSolve/test/core_tests.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,16 +47,16 @@ end
4747
b3 = p # b will be updated by explicitfun3
4848
prob3 = LinearProblem(A3, b3, zeros(3))
4949
function explicitfun3(p, sols)
50-
p[1] = sols[1][1] + sols[1][2] + sols[2][1] + sols[2][2] + sols[2][3]
51-
p[2] = sols[1][1] + sols[1][2] + sols[2][1] + 2.0sols[2][2] + sols[2][3]
52-
p[3] = sols[1][1] + 2.0sols[1][2] + 3.0sols[2][1] + 5.0sols[2][2] +
53-
6.0sols[2][3]
50+
p[1] = -(sols[1][1] + sols[1][2] + sols[2][1] + sols[2][2] + sols[2][3])
51+
p[2] = -(sols[1][1] + sols[1][2] + sols[2][1] + 2.0sols[2][2] + sols[2][3])
52+
p[3] = -(sols[1][1] + 2.0sols[1][2] + 3.0sols[2][1] + 5.0sols[2][2] +
53+
6.0sols[2][3])
5454
end
5555
explicitfun3(p, [sol1, sol2])
5656
sol3 = solve(prob3) # LinearProblem uses default linear solver
57-
manualscc = [sol1; sol2; sol3]
57+
manualscc = reduce(vcat,(sol1, sol2, sol3))
5858

59-
sccprob = SciMLBase.SCCNonlinearProblem([prob1, prob2, prob3],
59+
sccprob = SciMLBase.SCCNonlinearProblem((prob1, prob2, prob3),
6060
SciMLBase.Void{Any}.([explicitfun1, explicitfun2, explicitfun3]))
6161

6262
# Test with SCCAlg that handles both nonlinear and linear problems

0 commit comments

Comments
 (0)