@@ -3,6 +3,7 @@ module SCCNonlinearSolve
3
3
import SciMLBase
4
4
import CommonSolve
5
5
import SymbolicIndexingInterface
6
+ import SciMLBase: NonlinearProblem, NonlinearLeastSquaresProblem, LinearProblem
6
7
7
8
"""
8
9
SCCAlg(; nlalg = nothing, linalg = nothing)
@@ -25,41 +26,42 @@ function CommonSolve.solve(prob::SciMLBase.SCCNonlinearProblem; kwargs...)
25
26
CommonSolve. solve (prob, SCCAlg (nothing , nothing ); kwargs... )
26
27
end
27
28
29
+ probvec (prob:: Union{NonlinearProblem, NonlinearLeastSquaresProblem} ) = prob. u0
30
+ probvec (prob:: LinearProblem ) = prob. b
31
+
32
+ iteratively_build_sols (alg, sols; kwargs... ) = sols
33
+
34
+ function iteratively_build_sols (alg, sols, (prob, explicitfun), args... ; kwargs... )
35
+ explicitfun (
36
+ SymbolicIndexingInterface. parameter_values (prob), sols)
37
+
38
+ _sol = if prob isa SciMLBase. LinearProblem
39
+ sol = SciMLBase. solve (prob, alg. linalg; kwargs... )
40
+ SciMLBase. build_linear_solution (
41
+ alg. linalg, sol. u, nothing , nothing , retcode = sol. retcode)
42
+ else
43
+ sol = SciMLBase. solve (prob, alg. nlalg; kwargs... )
44
+ SciMLBase. build_solution (
45
+ prob, nothing , sol. u, sol. resid, retcode = sol. retcode)
46
+ end
47
+
48
+ iteratively_build_sols (alg, (sols... , _sol), args... )
49
+ end
50
+
28
51
function CommonSolve. solve (prob:: SciMLBase.SCCNonlinearProblem , alg:: SCCAlg ; kwargs... )
29
52
numscc = length (prob. probs)
30
- sols = [SciMLBase. build_solution (
31
- prob, nothing , prob. u0, convert (eltype (prob. u0), NaN ) * prob. u0)
32
- for prob in prob. probs]
33
- u = reduce (vcat, [prob. u0 for prob in prob. probs])
34
- resid = copy (u)
35
-
36
- lasti = 1
37
- for i in 1 : numscc
38
- prob. explictfuns, sols)
40
-
41
- if prob. probs[i] isa SciMLBase. LinearProblem
42
- sol = SciMLBase. solve (prob. probs[i], alg. linalg; kwargs... )
43
- _sol = SciMLBase. build_solution (
44
- prob. probs[i], nothing , sol. u, zero (sol. u), retcode = sol. retcode)
45
- else
46
- sol = SciMLBase. solve (prob. probs[i], alg. nlalg; kwargs... )
47
- _sol = SciMLBase. build_solution (
48
- prob. probs[i], nothing , sol. u, sol. resid, retcode = sol. retcode)
49
- end
50
-
51
- sols[i] = _sol
52
- lasti = i
53
- if ! SciMLBase. successful_retcode (_sol)
54
- break
55
- end
56
- end
53
+ sols = iteratively_build_sols (alg, (), zip (prob. probs, prob. explicitfuns!)... ; kwargs... )
57
54
58
55
# TODO : fix allocations with a lazy concatenation
59
- u . = reduce (vcat, sols)
60
- resid . = reduce (vcat, getproperty .(sols, :resid ))
56
+ u = reduce (vcat, sols)
57
+ resid = reduce (vcat, getproperty .(sols, :resid ))
61
58
62
- retcode = sols[lasti]. retcode
59
+ retcode = if ! all (SciMLBase. successful_retcode, sols)
60
+ idx = findfirst (! SciMLBase. successful_retcode, sols)
61
+ sols[idx]. retcode
62
+ else
63
+ SciMLBase. ReturnCode. Success
64
+ end
63
65
64
66
SciMLBase. build_solution (prob, alg, u, resid; retcode, original = sols)
65
67
end
0 commit comments