Skip to content

Commit e30b10d

Browse files
authored
Merge pull request #384 from SciML/ap/fix_internalnorm
Patches for DiffEqCallbacks
2 parents 156e65b + 8767492 commit e30b10d

File tree

4 files changed

+30
-21
lines changed

4 files changed

+30
-21
lines changed

.github/workflows/Downstream.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ jobs:
2525
- {user: SciML, repo: OrdinaryDiffEq.jl, group: Interface}
2626
- {user: SciML, repo: OrdinaryDiffEq.jl, group: Regression}
2727
- {user: SciML, repo: BoundaryValueDiffEq.jl, group: All}
28+
- {user: SciML, repo: DiffEqCallbacks.jl, group: All}
2829
steps:
2930
- uses: actions/checkout@v4
3031
- uses: julia-actions/setup-julia@v1

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NonlinearSolve"
22
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
33
authors = ["SciML"]
4-
version = "3.7.1"
4+
version = "3.7.2"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/core/generic.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,13 @@ function SciMLBase.solve!(cache::AbstractNonlinearSolveCache)
2424
update_trace!(cache.trace, get_nsteps(cache), get_u(cache),
2525
get_fu(cache), nothing, nothing, nothing; last = True)
2626

27-
stats = ImmutableNLStats(get_nf(cache), get_njacs(cache), get_nfactors(cache),
28-
get_nsolve(cache), get_nsteps(cache))
27+
return SciMLBase.build_solution(cache.prob, cache.alg, get_u(cache), get_fu(cache);
28+
cache.retcode, stats = __compile_stats(cache), cache.trace)
29+
end
2930

30-
return SciMLBase.build_solution(cache.prob, cache.alg, get_u(cache),
31-
get_fu(cache); cache.retcode, stats, cache.trace)
31+
function __compile_stats(cache::AbstractNonlinearSolveCache)
32+
return ImmutableNLStats(get_nf(cache), get_njacs(cache), get_nfactors(cache),
33+
get_nsolve(cache), get_nsteps(cache))
3234
end
3335

3436
"""

src/default.jl

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ end
5656
retcode::ReturnCode.T
5757
force_stop::Bool
5858
maxiters::Int
59+
internalnorm
5960
end
6061

6162
function Base.show(
@@ -80,10 +81,13 @@ end
8081
for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProblem, :NLLS))
8182
algType = NonlinearSolvePolyAlgorithm{pType}
8283
@eval begin
83-
function SciMLBase.__init(prob::$probType, alg::$algType{N}, args...;
84-
maxtime = nothing, maxiters = 1000, kwargs...) where {N}
84+
function SciMLBase.__init(
85+
prob::$probType, alg::$algType{N}, args...; maxtime = nothing,
86+
maxiters = 1000, internalnorm = DEFAULT_NORM, kwargs...) where {N}
8587
return NonlinearSolvePolyAlgorithmCache{isinplace(prob), N, maxtime !== nothing}(
86-
map(solver -> SciMLBase.__init(prob, solver, args...; maxtime, kwargs...),
88+
map(
89+
solver -> SciMLBase.__init(
90+
prob, solver, args...; maxtime, internalnorm, kwargs...),
8791
alg.algs),
8892
alg,
8993
-1,
@@ -93,7 +97,8 @@ for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProb
9397
maxtime,
9498
ReturnCode.Default,
9599
false,
96-
maxiters)
100+
maxiters,
101+
internalnorm)
97102
end
98103
end
99104
end
@@ -134,8 +139,8 @@ end
134139
push!(calls,
135140
quote
136141
fus = tuple($(Tuple(resids)...))
137-
minfu, idx = __findmin(cache.caches[1].internalnorm, fus)
138-
stats = cache.caches[idx].stats
142+
minfu, idx = __findmin(cache.internalnorm, fus)
143+
stats = __compile_stats(cache.caches[idx])
139144
u = get_u(cache.caches[idx])
140145
retcode = cache.caches[idx].retcode
141146

@@ -171,16 +176,15 @@ end
171176
end)
172177
end
173178

174-
push!(calls,
175-
quote
176-
if !(1 cache.current length(cache.caches))
177-
minfu, idx = __findmin(first(cache.caches).internalnorm, cache.caches)
178-
cache.best = idx
179-
cache.retcode = cache.caches[cache.best].retcode
180-
cache.force_stop = true
181-
return
182-
end
183-
end)
179+
push!(calls, quote
180+
if !(1 cache.current length(cache.caches))
181+
minfu, idx = __findmin(cache.internalnorm, cache.caches)
182+
cache.best = idx
183+
cache.retcode = cache.caches[cache.best].retcode
184+
cache.force_stop = true
185+
return
186+
end
187+
end)
184188

185189
return Expr(:block, calls...)
186190
end
@@ -353,9 +357,11 @@ function FastShortcutNLLSPolyalg(::Type{T} = Float64; concrete_jac = nothing,
353357
linsolve = nothing, precs = DEFAULT_PRECS, kwargs...) where {T}
354358
if __is_complex(T)
355359
algs = (GaussNewton(; concrete_jac, linsolve, precs, kwargs...),
360+
LevenbergMarquardt(; linsolve, precs, disable_geodesic = Val(true), kwargs...),
356361
LevenbergMarquardt(; linsolve, precs, kwargs...))
357362
else
358363
algs = (GaussNewton(; concrete_jac, linsolve, precs, kwargs...),
364+
LevenbergMarquardt(; linsolve, precs, disable_geodesic = Val(true), kwargs...),
359365
TrustRegion(; concrete_jac, linsolve, precs, kwargs...),
360366
GaussNewton(; concrete_jac, linsolve, precs,
361367
linesearch = LineSearchesJL(; method = BackTracking()), kwargs...),

0 commit comments

Comments
 (0)