Skip to content

Commit 6084a88

Browse files
committed
Move around the tests a bit
1 parent 1e27dd6 commit 6084a88

29 files changed

+325
-467
lines changed

.github/workflows/CI.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ jobs:
1818
fail-fast: false
1919
matrix:
2020
group:
21-
- Core
22-
- NLLS
21+
- RootFinding
22+
- NLLSSolvers
2323
- 23TestProblems
2424
- Wrappers
2525
- Miscellaneous

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ SparseDiffTools = "2.14"
9797
SpeedMapping = "0.3"
9898
StableRNGs = "1"
9999
StaticArrays = "1.7"
100+
Sundials = "4.23.1"
100101
Symbolics = "5.13"
101102
Test = "1"
102103
UnPack = "1.0"
@@ -128,9 +129,10 @@ SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
128129
SpeedMapping = "f1835b91-879b-4a3f-a438-e4baacf14412"
129130
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
130131
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
132+
Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
131133
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
132134
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
133135
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
134136

135137
[targets]
136-
test = ["Aqua", "Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim", "FastLevenbergMarquardt", "NaNMath", "BandedMatrices", "DiffEqBase", "StableRNGs", "MINPACK", "NLsolve", "OrdinaryDiffEq", "SpeedMapping", "FixedPointAcceleration", "SIAMFANLEquations"]
138+
test = ["Aqua", "Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim", "FastLevenbergMarquardt", "NaNMath", "BandedMatrices", "DiffEqBase", "StableRNGs", "MINPACK", "NLsolve", "OrdinaryDiffEq", "SpeedMapping", "FixedPointAcceleration", "SIAMFANLEquations", "Sundials"]

docs/src/solvers/FixedPointSolvers.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,10 @@ We are only listing the methods that natively solve fixed point problems.
4040

4141
- `FixedPointAccelerationJL()`: accelerates the convergence of a mapping to a fixed point
4242
by the Anderson acceleration algorithm and a few other methods.
43+
44+
### NLsolve.jl
45+
46+
In our tests, we have found the anderson method implemented here to NOT be the most
47+
robust.
48+
49+
- `NLsolveJL(; method = :anderson)`: Anderson acceleration for fixed point problems.

ext/NonlinearSolveFastLevenbergMarquardtExt.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import FiniteDiff, ForwardDiff
1515
end
1616
end
1717

18+
# TODO: Implement reinit
1819
@concrete struct FastLevenbergMarquardtJLCache
1920
f!
2021
J!

ext/NonlinearSolveFixedPointAccelerationExt.jl

Lines changed: 7 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -9,41 +9,25 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::FixedPointAccelerationJL
99
@assert (termination_condition ===
1010
nothing)||(termination_condition isa AbsNormTerminationMode) "FixedPointAccelerationJL does not support termination conditions!"
1111

12-
u0 = NonlinearSolve.__maybe_unaliased(prob.u0, alias_u0)
13-
u_size = size(u0)
14-
T = eltype(u0)
15-
iip = isinplace(prob)
16-
p = prob.p
17-
18-
if !iip && prob.u0 isa Number
19-
# FixedPointAcceleration makes the scalar problem into a vector problem
20-
f = (u) -> [prob.f(u[1], p) .+ u[1]]
21-
elseif !iip && prob.u0 isa AbstractVector
22-
f = (u) -> (prob.f(u, p) .+ u)
23-
elseif !iip && prob.u0 isa AbstractArray
24-
f = (u) -> vec(prob.f(reshape(u, u_size), p) .+ u)
25-
elseif iip && prob.u0 isa AbstractVector
26-
du = similar(u0)
27-
f = (u) -> (prob.f(du, u, p); du .+ u)
28-
else
29-
du = similar(u0)
30-
f = (u) -> (prob.f(du, reshape(u, u_size), p); vec(du) .+ u)
31-
end
12+
f, u0 = NonlinearSolve.__construct_f(prob; alias_u0, make_fixed_point = Val(true),
13+
force_oop = Val(true))
3214

3315
tol = NonlinearSolve.DEFAULT_TOLERANCE(abstol, eltype(u0))
3416

35-
sol = fixed_point(f, NonlinearSolve._vec(u0); Algorithm = alg.algorithm,
17+
sol = fixed_point(f, u0; Algorithm = alg.algorithm,
3618
ConvergenceMetricThreshold = tol, MaxIter = maxiters, MaxM = alg.m,
3719
ExtrapolationPeriod = alg.extrapolation_period, Dampening = alg.dampening,
3820
PrintReports, ReplaceInvalids = alg.replace_invalids,
3921
ConditionNumberThreshold = alg.condition_number_threshold, quiet_errors = true)
4022

41-
res = prob.u0 isa Number ? first(sol.FixedPoint_) : sol.FixedPoint_
42-
if res === missing
23+
if sol.FixedPoint_ === missing
24+
u0 = prob.u0 isa Number ? u0[1] : u0
4325
resid = NonlinearSolve.evaluate_f(prob, u0)
4426
res = u0
4527
converged = false
4628
else
29+
res = prob.u0 isa Number ? first(sol.FixedPoint_) :
30+
reshape(sol.FixedPoint_, size(prob.u0))
4731
resid = NonlinearSolve.evaluate_f(prob, res)
4832
converged = maximum(abs, resid) tol
4933
end

ext/NonlinearSolveMINPACKExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ function SciMLBase.__solve(prob::Union{NonlinearProblem{uType, iip},
3131
original = MINPACK.fsolve(f!, u0, m; tol, show_trace, tracing, method,
3232
iterations = maxiters)
3333
else
34-
jac! = @closure((J, u) -> (jac!_(J, u); Cint(0)))
34+
jac! = @closure((J, u)->(jac!_(J, u); Cint(0)))
3535
original = MINPACK.fsolve(f!, jac!, u0, m; tol, show_trace, tracing, method,
3636
iterations = maxiters)
3737
end

ext/NonlinearSolveSIAMFANLEquationsExt.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, arg
5858

5959
retcode = __siam_fanl_equations_retcode_mapping(sol)
6060
stats = __siam_fanl_equations_stats_mapping(method, sol)
61-
return SciMLBase.build_solution(prob, alg, sol.solution, sol.history; retcode,
61+
resid = NonlinearSolve.evaluate_f(prob, sol.solution)
62+
return SciMLBase.build_solution(prob, alg, sol.solution, resid; retcode,
6263
stats, original = sol)
6364
end
6465

@@ -87,7 +88,8 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, arg
8788

8889
retcode = __siam_fanl_equations_retcode_mapping(sol)
8990
stats = __siam_fanl_equations_stats_mapping(method, sol)
90-
return SciMLBase.build_solution(prob, alg, sol.solution, sol.history; retcode,
91+
resid = NonlinearSolve.evaluate_f(prob, sol.solution)
92+
return SciMLBase.build_solution(prob, alg, sol.solution, resid; retcode,
9193
stats, original = sol)
9294
end
9395

@@ -116,7 +118,8 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, arg
116118

117119
retcode = __siam_fanl_equations_retcode_mapping(sol)
118120
stats = __siam_fanl_equations_stats_mapping(method, sol)
119-
return SciMLBase.build_solution(prob, alg, sol.solution, sol.history; retcode, stats,
121+
resid = NonlinearSolve.evaluate_f(prob, sol.solution)
122+
return SciMLBase.build_solution(prob, alg, sol.solution, resid; retcode, stats,
120123
original = sol)
121124
end
122125

ext/NonlinearSolveSpeedMappingExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SpeedMappingJL, args...;
1717
sol = speedmapping(u0; m!, tol, Lp = Inf, maps_limit = maxiters, alg.orders,
1818
alg.check_obj, store_info, alg.σ_min, alg.stabilize)
1919
res = prob.u0 isa Number ? first(sol.minimizer) : sol.minimizer
20-
resid = NonlinearSolve.evaluate_f(prob, sol.minimizer)
20+
resid = NonlinearSolve.evaluate_f(prob, res)
2121

2222
return SciMLBase.build_solution(prob, alg, res, resid;
2323
retcode = sol.converged ? ReturnCode.Success : ReturnCode.Failure,

src/function_wrappers.jl

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,14 @@
22
# downstream packages. Make conversion to those easier.
33
function __construct_f(prob; alias_u0::Bool = false, can_handle_oop::Val{OOP} = Val(false),
44
can_handle_scalar::Val{SCALAR} = Val(false), make_fixed_point::Val{FP} = Val(false),
5-
can_handle_arbitrary_dims::Val{DIMS} = Val(false)) where {SCALAR, OOP, DIMS, FP}
5+
can_handle_arbitrary_dims::Val{DIMS} = Val(false),
6+
force_oop::Val{FOOP} = Val(false)) where {SCALAR, OOP, DIMS, FP, FOOP}
67
if !OOP && SCALAR
78
error("Incorrect Specification: OOP not supported but scalar supported.")
89
end
910

11+
resid = evaluate_f(prob, prob.u0)
12+
1013
if SCALAR || !(prob.u0 isa Number)
1114
u0 = __maybe_unaliased(prob.u0, alias_u0)
1215
else
@@ -20,51 +23,69 @@ function __construct_f(prob; alias_u0::Bool = false, can_handle_oop::Val{OOP} =
2023
@. du += u
2124
end
2225
else
23-
@closure (du, u, p) -> prob.f(du, u, p) .+ u
26+
@closure (u, p) -> prob.f(u, p) .+ u
2427
end
2528
else
2629
prob.f
2730
end
2831

29-
f_final = if isinplace(prob)
32+
ff = if isinplace(prob)
33+
ninputs = 2
3034
if DIMS || u0 isa AbstractVector
3135
@closure (du, u) -> (f(du, u, prob.p); du)
3236
else
3337
u0_size = size(u0)
34-
du_size = size(evaluate_f(prob, u0))
38+
du_size = size(resid)
3539
@closure (du, u) -> (f(reshape(du, du_size), reshape(u, u0_size), prob.p); du)
3640
end
3741
else
3842
if prob.u0 isa Number
3943
if SCALAR
40-
@closure (u) -> prob.f(u, prob.p)
44+
ninputs = 1
45+
@closure (u) -> f(u, prob.p)
4146
elseif OOP
42-
@closure (u) -> [prob.f(first(u), prob.p)]
47+
ninputs = 1
48+
@closure (u) -> [f(first(u), prob.p)]
4349
else
44-
@closure (du, u) -> (du[1] = prob.f(first(u), prob.p); du)
50+
ninputs = 2
51+
resid = [resid]
52+
@closure (du, u) -> (du[1] = f(first(u), prob.p); du)
4553
end
4654
else
4755
if OOP
56+
ninputs = 1
4857
if DIMS
49-
@closure (u) -> prob.f(u, prob.p)
58+
@closure (u) -> f(u, prob.p)
5059
else
5160
u0_size = size(u0)
52-
@closure (u) -> _vec(prob.f(reshape(u, u0_size), prob.p))
61+
@closure (u) -> _vec(f(reshape(u, u0_size), prob.p))
5362
end
5463
else
64+
ninputs = 2
5565
if DIMS
56-
@closure (du, u) -> (copyto!(du, prob.f(u, prob.p)); du)
66+
@closure (du, u) -> (copyto!(du, f(u, prob.p)); du)
5767
else
5868
u0_size = size(u0)
5969
@closure (du, u) -> begin
60-
copyto!(vec(du), vec(prob.f(reshape(u, u0_size), prob.p)))
70+
copyto!(vec(du), vec(f(reshape(u, u0_size), prob.p)))
6171
return du
6272
end
6373
end
6474
end
6575
end
6676
end
6777

78+
f_final = if FOOP
79+
if ninputs == 1
80+
ff
81+
else
82+
du_ = DIMS ? similar(resid) : _vec(similar(resid))
83+
@closure (u) -> (ff(du_, u); du_)
84+
end
85+
else
86+
ff
87+
end
88+
6889
return f_final, ifelse(DIMS, u0, _vec(u0))
6990
end
7091

File renamed without changes.

0 commit comments

Comments
 (0)