Skip to content

Commit 133b1d5

Browse files
committed
handle complex numbers correctly by default
1 parent aa282b7 commit 133b1d5

File tree

4 files changed

+145
-59
lines changed

4 files changed

+145
-59
lines changed

Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NonlinearSolve"
22
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
33
authors = ["SciML"]
4-
version = "3.1.1"
4+
version = "3.1.2"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -71,6 +71,7 @@ MaybeInplace = "0.1.1"
7171
NLsolve = "4.5"
7272
NaNMath = "1"
7373
NonlinearProblemLibrary = "0.1.1"
74+
OrdinaryDiffEq = "6"
7475
Pkg = "1"
7576
PrecompileTools = "1.2"
7677
Printf = "<0.0.1, 1"
@@ -106,6 +107,7 @@ MINPACK = "4854310b-de5a-5eb6-a2a5-c1dee2bd17f9"
106107
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
107108
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
108109
NonlinearProblemLibrary = "b7050fa9-e91f-4b37-bcee-a89a063da141"
110+
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
109111
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
110112
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
111113
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
@@ -117,4 +119,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
117119
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
118120

119121
[targets]
120-
test = ["Aqua", "Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim", "FastLevenbergMarquardt", "NaNMath", "BandedMatrices", "DiffEqBase", "StableRNGs", "MINPACK", "NLsolve"]
122+
test = ["Aqua", "Enzyme", "BenchmarkTools", "SafeTestsets", "Pkg", "Test", "ForwardDiff", "StaticArrays", "Symbolics", "LinearSolve", "Random", "LinearAlgebra", "Zygote", "SparseDiffTools", "NonlinearProblemLibrary", "LeastSquaresOptim", "FastLevenbergMarquardt", "NaNMath", "BandedMatrices", "DiffEqBase", "StableRNGs", "MINPACK", "NLsolve", "OrdinaryDiffEq"]

src/default.jl

Lines changed: 110 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -165,8 +165,8 @@ function SciMLBase.reinit!(cache::NonlinearSolvePolyAlgorithmCache, args...; kwa
165165
end
166166

167167
"""
168-
RobustMultiNewton(; concrete_jac = nothing, linsolve = nothing, precs = DEFAULT_PRECS,
169-
autodiff = nothing)
168+
RobustMultiNewton(::Type{T} = Float64; concrete_jac = nothing, linsolve = nothing,
169+
precs = DEFAULT_PRECS, autodiff = nothing)
170170
171171
A polyalgorithm focused on robustness. It uses a mixture of Newton methods with different
172172
globalizing techniques (trust region updates, line searches, etc.) in order to find a
@@ -176,6 +176,11 @@ Basically, if this algorithm fails, then "most" good ways of solving your proble
176176
you may need to think about reformulating the model (either there is an issue with the model,
177177
or more precision / more stable linear solver choice is required).
178178
179+
### Arguments
180+
181+
- `T`: The eltype of the initial guess. It is only used to check if some of the algorithms
182+
are compatible with the problem type. Defaults to `Float64`.
183+
179184
### Keyword Arguments
180185
181186
- `autodiff`: determines the backend used for the Jacobian. Note that this argument is
@@ -196,28 +201,38 @@ or more precision / more stable linear solver choice is required).
196201
algorithms, consult the
197202
[LinearSolve.jl documentation](https://docs.sciml.ai/LinearSolve/stable/).
198203
"""
199-
function RobustMultiNewton(; concrete_jac = nothing, linsolve = nothing,
200-
precs = DEFAULT_PRECS, autodiff = nothing)
201-
algs = (TrustRegion(; concrete_jac, linsolve, precs),
202-
TrustRegion(; concrete_jac, linsolve, precs, autodiff,
203-
radius_update_scheme = RadiusUpdateSchemes.Bastin),
204-
NewtonRaphson(; concrete_jac, linsolve, precs, linesearch = BackTracking(),
205-
autodiff),
206-
TrustRegion(; concrete_jac, linsolve, precs,
207-
radius_update_scheme = RadiusUpdateSchemes.NLsolve, autodiff),
208-
TrustRegion(; concrete_jac, linsolve, precs,
209-
radius_update_scheme = RadiusUpdateSchemes.Fan, autodiff))
204+
function RobustMultiNewton(::Type{T} = Float64; concrete_jac = nothing, linsolve = nothing,
205+
precs = DEFAULT_PRECS, autodiff = nothing) where {T}
206+
if __is_complex(T)
207+
# Let's at least have something here for complex numbers
208+
algs = (NewtonRaphson(; concrete_jac, linsolve, precs, autodiff),)
209+
else
210+
algs = (TrustRegion(; concrete_jac, linsolve, precs),
211+
TrustRegion(; concrete_jac, linsolve, precs, autodiff,
212+
radius_update_scheme = RadiusUpdateSchemes.Bastin),
213+
NewtonRaphson(; concrete_jac, linsolve, precs, linesearch = BackTracking(),
214+
autodiff),
215+
TrustRegion(; concrete_jac, linsolve, precs,
216+
radius_update_scheme = RadiusUpdateSchemes.NLsolve, autodiff),
217+
TrustRegion(; concrete_jac, linsolve, precs,
218+
radius_update_scheme = RadiusUpdateSchemes.Fan, autodiff))
219+
end
210220
return NonlinearSolvePolyAlgorithm(algs, Val(:NLS))
211221
end
212222

213223
"""
214-
FastShortcutNonlinearPolyalg(; concrete_jac = nothing, linsolve = nothing,
215-
precs = DEFAULT_PRECS, must_use_jacobian::Val = Val(false),
216-
prefer_simplenonlinearsolve::Val{SA} = Val(false), autodiff = nothing)
224+
FastShortcutNonlinearPolyalg(::Type{T} = Float64; concrete_jac = nothing,
225+
linsolve = nothing, precs = DEFAULT_PRECS, must_use_jacobian::Val = Val(false),
226+
prefer_simplenonlinearsolve::Val{SA} = Val(false), autodiff = nothing) where {T}
217227
218228
A polyalgorithm focused on balancing speed and robustness. It first tries less robust methods
219229
for more performance and then tries more robust techniques if the faster ones fail.
220230
231+
### Arguments
232+
233+
- `T`: The eltype of the initial guess. It is only used to check if some of the algorithms
234+
are compatible with the problem type. Defaults to `Float64`.
235+
221236
### Keyword Arguments
222237
223238
- `autodiff`: determines the backend used for the Jacobian. Note that this argument is
@@ -238,53 +253,76 @@ for more performance and then tries more robust techniques if the faster ones fa
238253
algorithms, consult the
239254
[LinearSolve.jl documentation](https://docs.sciml.ai/LinearSolve/stable/).
240255
"""
241-
function FastShortcutNonlinearPolyalg(; concrete_jac = nothing, linsolve = nothing,
242-
precs = DEFAULT_PRECS, must_use_jacobian::Val{JAC} = Val(false),
256+
function FastShortcutNonlinearPolyalg(::Type{T} = Float64; concrete_jac = nothing,
257+
linsolve = nothing, precs = DEFAULT_PRECS, must_use_jacobian::Val{JAC} = Val(false),
243258
prefer_simplenonlinearsolve::Val{SA} = Val(false),
244-
autodiff = nothing) where {JAC, SA}
259+
autodiff = nothing) where {T, JAC, SA}
245260
if JAC
246-
algs = (NewtonRaphson(; concrete_jac, linsolve, precs, autodiff),
247-
NewtonRaphson(; concrete_jac, linsolve, precs, linesearch = BackTracking(),
248-
autodiff),
249-
TrustRegion(; concrete_jac, linsolve, precs, autodiff),
250-
TrustRegion(; concrete_jac, linsolve, precs,
251-
radius_update_scheme = RadiusUpdateSchemes.Bastin, autodiff))
252-
else
253-
# SimpleNewtonRaphson and SimpleTrustRegion are not robust to singular Jacobians
254-
# and thus are not included in the polyalgorithm
255-
if SA
256-
algs = (SimpleBroyden(),
257-
Broyden(; init_jacobian = Val(:true_jacobian)),
258-
SimpleKlement(),
259-
NewtonRaphson(; concrete_jac, linsolve, precs, autodiff),
260-
NewtonRaphson(; concrete_jac, linsolve, precs, linesearch = BackTracking(),
261-
autodiff),
262-
NewtonRaphson(; concrete_jac, linsolve, precs, linesearch = BackTracking(),
263-
autodiff),
264-
TrustRegion(; concrete_jac, linsolve, precs,
265-
radius_update_scheme = RadiusUpdateSchemes.Bastin, autodiff))
261+
if __is_complex(T)
262+
algs = (NewtonRaphson(; concrete_jac, linsolve, precs, autodiff),)
266263
else
267-
algs = (Broyden(),
268-
Broyden(; init_jacobian = Val(:true_jacobian)),
269-
Klement(; linsolve, precs),
270-
NewtonRaphson(; concrete_jac, linsolve, precs, autodiff),
264+
algs = (NewtonRaphson(; concrete_jac, linsolve, precs, autodiff),
271265
NewtonRaphson(; concrete_jac, linsolve, precs, linesearch = BackTracking(),
272266
autodiff),
273267
TrustRegion(; concrete_jac, linsolve, precs, autodiff),
274268
TrustRegion(; concrete_jac, linsolve, precs,
275269
radius_update_scheme = RadiusUpdateSchemes.Bastin, autodiff))
276270
end
271+
else
272+
# SimpleNewtonRaphson and SimpleTrustRegion are not robust to singular Jacobians
273+
# and thus are not included in the polyalgorithm
274+
if SA
275+
if __is_complex(T)
276+
algs = (SimpleBroyden(),
277+
Broyden(; init_jacobian = Val(:true_jacobian)),
278+
SimpleKlement(),
279+
NewtonRaphson(; concrete_jac, linsolve, precs, autodiff))
280+
else
281+
algs = (SimpleBroyden(),
282+
Broyden(; init_jacobian = Val(:true_jacobian)),
283+
SimpleKlement(),
284+
NewtonRaphson(; concrete_jac, linsolve, precs, autodiff),
285+
NewtonRaphson(; concrete_jac, linsolve, precs,
286+
linesearch = BackTracking(), autodiff),
287+
NewtonRaphson(; concrete_jac, linsolve, precs,
288+
linesearch = BackTracking(), autodiff),
289+
TrustRegion(; concrete_jac, linsolve, precs,
290+
radius_update_scheme = RadiusUpdateSchemes.Bastin, autodiff))
291+
end
292+
else
293+
if __is_complex(T)
294+
algs = (Broyden(),
295+
Broyden(; init_jacobian = Val(:true_jacobian)),
296+
Klement(; linsolve, precs),
297+
NewtonRaphson(; concrete_jac, linsolve, precs, autodiff))
298+
else
299+
algs = (Broyden(),
300+
Broyden(; init_jacobian = Val(:true_jacobian)),
301+
Klement(; linsolve, precs),
302+
NewtonRaphson(; concrete_jac, linsolve, precs, autodiff),
303+
NewtonRaphson(; concrete_jac, linsolve, precs,
304+
linesearch = BackTracking(), autodiff),
305+
TrustRegion(; concrete_jac, linsolve, precs, autodiff),
306+
TrustRegion(; concrete_jac, linsolve, precs,
307+
radius_update_scheme = RadiusUpdateSchemes.Bastin, autodiff))
308+
end
309+
end
277310
end
278311
return NonlinearSolvePolyAlgorithm(algs, Val(:NLS))
279312
end
280313

281314
"""
282-
FastShortcutNLLSPolyalg(; concrete_jac = nothing, linsolve = nothing,
283-
precs = DEFAULT_PRECS, kwargs...)
315+
FastShortcutNLLSPolyalg(::Type{T} = Float64; concrete_jac = nothing, linsolve = nothing,
316+
precs = DEFAULT_PRECS, kwargs...)
284317
285318
A polyalgorithm focused on balancing speed and robustness. It first tries less robust methods
286319
for more performance and then tries more robust techniques if the faster ones fail.
287320
321+
### Arguments
322+
323+
- `T`: The eltype of the initial guess. It is only used to check if some of the algorithms
324+
are compatible with the problem type. Defaults to `Float64`.
325+
288326
### Keyword Arguments
289327
290328
- `autodiff`: determines the backend used for the Jacobian. Note that this argument is
@@ -305,42 +343,58 @@ for more performance and then tries more robust techniques if the faster ones fa
305343
algorithms, consult the
306344
[LinearSolve.jl documentation](https://docs.sciml.ai/LinearSolve/stable/).
307345
"""
308-
function FastShortcutNLLSPolyalg(; concrete_jac = nothing, linsolve = nothing,
309-
precs = DEFAULT_PRECS, kwargs...)
310-
algs = (GaussNewton(; concrete_jac, linsolve, precs, kwargs...),
311-
GaussNewton(; concrete_jac, linsolve, precs, linesearch = BackTracking(),
312-
kwargs...),
313-
LevenbergMarquardt(; concrete_jac, linsolve, precs, kwargs...))
346+
function FastShortcutNLLSPolyalg(::Type{T} = Float64; concrete_jac = nothing,
347+
linsolve = nothing, precs = DEFAULT_PRECS, kwargs...) where {T}
348+
if __is_complex(T)
349+
algs = (GaussNewton(; concrete_jac, linsolve, precs, kwargs...),
350+
LevenbergMarquardt(; concrete_jac, linsolve, precs, kwargs...))
351+
else
352+
algs = (GaussNewton(; concrete_jac, linsolve, precs, kwargs...),
353+
TrustRegion(; concrete_jac, linsolve, precs, kwargs...),
354+
GaussNewton(; concrete_jac, linsolve, precs, linesearch = BackTracking(),
355+
kwargs...),
356+
TrustRegion(; concrete_jac, linsolve, precs,
357+
radius_update_scheme = RadiusUpdateSchemes.Bastin, kwargs...),
358+
LevenbergMarquardt(; concrete_jac, linsolve, precs, kwargs...))
359+
end
314360
return NonlinearSolvePolyAlgorithm(algs, Val(:NLLS))
315361
end
316362

317363
## Defaults
318364

319365
## TODO: In the long run we want to use an `Assumptions` API like LinearSolve to specify
320366
## the conditioning of the Jacobian and such
367+
368+
## TODO: Currently some of the algorithms like LineSearches / TrustRegion don't support
369+
## complex numbers. We should use the `DiffEqBase` trait for this once all of the
370+
## NonlinearSolve algorithms support it. For now we just do a check and remove the
371+
## unsupported ones from default
372+
321373
## Defaults to a fast and robust poly algorithm in most cases. If the user went through
322374
## the trouble of specifying a custom jacobian function, we should use algorithms that
323375
## can use that!
324-
325376
function SciMLBase.__init(prob::NonlinearProblem, ::Nothing, args...; kwargs...)
326377
must_use_jacobian = Val(prob.f.jac !== nothing)
327-
return SciMLBase.__init(prob, FastShortcutNonlinearPolyalg(; must_use_jacobian),
378+
return SciMLBase.__init(prob,
379+
FastShortcutNonlinearPolyalg(eltype(prob.u0); must_use_jacobian),
328380
args...; kwargs...)
329381
end
330382

331383
function SciMLBase.__solve(prob::NonlinearProblem, ::Nothing, args...; kwargs...)
332384
must_use_jacobian = Val(prob.f.jac !== nothing)
333385
prefer_simplenonlinearsolve = Val(prob.u0 isa SArray)
334386
return SciMLBase.__solve(prob,
335-
FastShortcutNonlinearPolyalg(; must_use_jacobian,
387+
FastShortcutNonlinearPolyalg(eltype(prob.u0); must_use_jacobian,
336388
prefer_simplenonlinearsolve), args...; kwargs...)
337389
end
338390

339391
function SciMLBase.__init(prob::NonlinearLeastSquaresProblem, ::Nothing, args...; kwargs...)
340-
return SciMLBase.__init(prob, FastShortcutNLLSPolyalg(), args...; kwargs...)
392+
return SciMLBase.__init(prob, FastShortcutNLLSPolyalg(eltype(prob.u0)), args...;
393+
kwargs...)
341394
end
342395

343396
function SciMLBase.__solve(prob::NonlinearLeastSquaresProblem, ::Nothing, args...;
344397
kwargs...)
345-
return SciMLBase.__solve(prob, FastShortcutNLLSPolyalg(), args...; kwargs...)
398+
return SciMLBase.__solve(prob, FastShortcutNLLSPolyalg(eltype(prob.u0)), args...;
399+
kwargs...)
346400
end

src/utils.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -494,3 +494,8 @@ end
494494
@inline __diag(x::AbstractMatrix) = diag(x)
495495
@inline __diag(x::AbstractVector) = x
496496
@inline __diag(x::Number) = x
497+
498+
@inline __is_complex(::Type{ComplexF64}) = true
499+
@inline __is_complex(::Type{ComplexF32}) = true
500+
# Any `Complex` type counts as complex: the bare `Complex` UnionAll as well as every
# parameterization (`Complex{BigFloat}`, `Complex{Float16}`, `Complex{Int}`, ...).
# `Type{Complex}` alone would only match the unparameterized `Complex`.
@inline __is_complex(::Type{<:Complex}) = true
501+
@inline __is_complex(::Type{T}) where {T} = false

test/polyalgs.jl

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using NonlinearSolve, Test, NaNMath
1+
using NonlinearSolve, Test, NaNMath, OrdinaryDiffEq
22

33
f(u, p) = u .* u .- 2
44
u0 = [1.0, 1.0]
@@ -58,3 +58,28 @@ p = 2.0
5858
prob = NonlinearProblem(ff_interval, u0, p)
5959
sol = solve(prob; abstol = 1e-9)
6060
@test SciMLBase.successful_retcode(sol)
61+
62+
# Shooting Problem: Taken from BoundaryValueDiffEq.jl
63+
# Testing for Complex Valued Root Finding. For Complex valued inputs we drop some of the
64+
# algorithms which don't support those.
65+
function ode_func!(du, u, p, t)
66+
du[1] = u[2]
67+
du[2] = -u[1]
68+
return nothing
69+
end
70+
71+
function objective_function!(resid, u0, p)
72+
odeprob = ODEProblem{true}(ode_func!, u0, (0.0, 100.0), p)
73+
sol = solve(odeprob, Tsit5(), abstol = 1e-9, reltol = 1e-9, verbose = false)
74+
resid[1] = sol(0.0)[1]
75+
resid[2] = sol(100.0)[1] - 1.0
76+
return nothing
77+
end
78+
79+
prob = NonlinearProblem{true}(objective_function!, [0.0, 1.0] .+ 1im)
80+
sol = solve(prob; abstol = 1e-10)
81+
@test SciMLBase.successful_retcode(sol)
82+
# This test is not meant to return success but test that all the default solvers can handle
83+
# complex valued problems
84+
@test_nowarn solve(prob; abstol = 1e-19, maxiters = 10)
85+
@test_nowarn solve(prob, RobustMultiNewton(eltype(prob.u0)); abstol = 1e-19, maxiters = 10)

0 commit comments

Comments
 (0)