Skip to content

Commit 541db50

Browse files
Support LinearProblems in SCCs
1 parent f7ee0ea commit 541db50

File tree

2 files changed

+45
-16
lines changed

2 files changed

+45
-16
lines changed

lib/SCCNonlinearSolve/src/SCCNonlinearSolve.jl

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,28 @@ import SciMLBase
44
import CommonSolve
55
import SymbolicIndexingInterface
66

7+
"""
8+
SCCAlg(; nlalg = nothing, linalg = nothing)
9+
10+
Algorithm for solving Strongly Connected Component (SCC) problems containing
11+
both nonlinear and linear subproblems.
12+
13+
### Keyword Arguments
14+
- `nlalg`: Algorithm to use for solving NonlinearProblem components
15+
- `linalg`: Algorithm to use for solving LinearProblem components
16+
"""
17+
struct SCCAlg{N, L}
18+
nlalg::N
19+
linalg::L
20+
end
21+
22+
SCCAlg(; nlalg = nothing, linalg = nothing) = SCCAlg(nlalg, linalg)
23+
724
function CommonSolve.solve(prob::SciMLBase.SCCNonlinearProblem; kwargs...)
8-
CommonSolve.solve(prob, nothing; kwargs...)
25+
CommonSolve.solve(prob, SCCAlg(nothing, nothing); kwargs...)
926
end
1027

11-
function CommonSolve.solve(prob::SciMLBase.SCCNonlinearProblem, alg; kwargs...)
28+
function CommonSolve.solve(prob::SciMLBase.SCCNonlinearProblem, alg::SCCAlg; kwargs...)
1229
numscc = length(prob.probs)
1330
sols = [SciMLBase.build_solution(
1431
prob, nothing, prob.u0, convert(eltype(prob.u0), NaN) * prob.u0)
@@ -20,9 +37,17 @@ function CommonSolve.solve(prob::SciMLBase.SCCNonlinearProblem, alg; kwargs...)
2037
for i in 1:numscc
2138
prob.explictfuns![i](
2239
SymbolicIndexingInterface.parameter_values(prob.probs[i]), sols)
23-
sol = SciMLBase.solve(prob.probs[i], alg; kwargs...)
24-
_sol = SciMLBase.build_solution(
25-
prob.probs[i], nothing, sol.u, sol.resid, retcode = sol.retcode)
40+
41+
if prob.probs[i] isa SciMLBase.LinearProblem
42+
sol = SciMLBase.solve(prob.probs[i], alg.linalg; kwargs...)
43+
_sol = SciMLBase.build_solution(
44+
prob.probs[i], nothing, sol.u, zero(sol.u), retcode = sol.retcode)
45+
else
46+
sol = SciMLBase.solve(prob.probs[i], alg.nlalg; kwargs...)
47+
_sol = SciMLBase.build_solution(
48+
prob.probs[i], nothing, sol.u, sol.resid, retcode = sol.retcode)
49+
end
50+
2651
sols[i] = _sol
2752
lasti = i
2853
if !SciMLBase.successful_retcode(_sol)
@@ -39,4 +64,5 @@ function CommonSolve.solve(prob::SciMLBase.SCCNonlinearProblem, alg; kwargs...)
3964
SciMLBase.build_solution(prob, alg, u, resid; retcode, original = sols)
4065
end
4166

67+
4268
end

lib/SCCNonlinearSolve/test/core_tests.jl

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -41,30 +41,33 @@ end
4141
NonlinearFunction{true, SciMLBase.NoSpecialize}(f2), zeros(3), p)
4242
sol2 = solve(prob2, NewtonRaphson())
4343

44-
function f3(du, u, p)
45-
du[1] = p[1] + 2.0u[1] + 2.5u[2] + 1.5u[3]
46-
du[2] = p[2] + 4.0u[1] - 1.5u[2] + 1.5u[3]
47-
du[3] = p[3] + +u[1] - u[2] - u[3]
48-
end
49-
prob3 = NonlinearProblem(
50-
NonlinearFunction{true, SciMLBase.NoSpecialize}(f3), zeros(3), p)
44+
# Convert f3 to a LinearProblem since it's linear in u
45+
# du = Au + b where A is the coefficient matrix and b is from parameters
46+
A3 = [2.0 2.5 1.5; 4.0 -1.5 1.5; 1.0 -1.0 -1.0]
47+
b3 = p # b will be updated by explicitfun3
48+
prob3 = LinearProblem(A3, b3, zeros(3))
5149
function explicitfun3(p, sols)
5250
p[1] = sols[1][1] + sols[1][2] + sols[2][1] + sols[2][2] + sols[2][3]
5351
p[2] = sols[1][1] + sols[1][2] + sols[2][1] + 2.0sols[2][2] + sols[2][3]
5452
p[3] = sols[1][1] + 2.0sols[1][2] + 3.0sols[2][1] + 5.0sols[2][2] +
5553
6.0sols[2][3]
5654
end
5755
explicitfun3(p, [sol1, sol2])
58-
sol3 = solve(prob3, NewtonRaphson())
56+
sol3 = solve(prob3) # LinearProblem uses default linear solver
5957
manualscc = [sol1; sol2; sol3]
6058

6159
sccprob = SciMLBase.SCCNonlinearProblem([prob1, prob2, prob3],
6260
SciMLBase.Void{Any}.([explicitfun1, explicitfun2, explicitfun3]))
63-
scc_sol = solve(sccprob, NewtonRaphson())
61+
62+
# Test with SCCAlg that handles both nonlinear and linear problems
63+
using SCCNonlinearSolve
64+
scc_alg = SCCNonlinearSolve.SCCAlg(nlalg = NewtonRaphson(), linalg = nothing)
65+
scc_sol = solve(sccprob, scc_alg)
6466
@test sol manualscc scc_sol
6567

6668
import NonlinearSolve # Required for Default
6769

68-
scc_sol = solve(sccprob)
69-
@test sol manualscc scc_sol
70+
# Test default interface
71+
scc_sol_default = solve(sccprob)
72+
@test sol manualscc scc_sol_default
7073
end

0 commit comments

Comments
 (0)