Skip to content

Commit 96b40e1

Browse files
committed
convert to immutable in ad dispatch
1 parent d309aea commit 96b40e1

File tree

1 file changed

+23
-12
lines changed
  • lib/SimpleNonlinearSolve/src

1 file changed

+23
-12
lines changed

lib/SimpleNonlinearSolve/src/ad.jl

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,26 @@
1-
for pType in (ImmutableNonlinearProblem, NonlinearLeastSquaresProblem)
2-
@eval function SciMLBase.solve(
3-
prob::$(pType){<:Union{Number, <:AbstractArray}, iip,
4-
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}},
5-
alg::AbstractSimpleNonlinearSolveAlgorithm,
6-
args...;
7-
kwargs...) where {T, V, P, iip}
8-
sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...)
9-
dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p)
10-
return SciMLBase.build_solution(
11-
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
12-
end
1+
function SciMLBase.solve(
2+
prob::NonlinearLeastSquaresProblem{<:Union{Number, <:AbstractArray}, iip,
3+
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}},
4+
alg::AbstractSimpleNonlinearSolveAlgorithm,
5+
args...;
6+
kwargs...) where {T, V, P, iip}
7+
sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...)
8+
dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p)
9+
return SciMLBase.build_solution(
10+
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
11+
end
12+
13+
function SciMLBase.solve(
14+
prob::NonlinearProblem{<:Union{Number, <:AbstractArray}, iip,
15+
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}},
16+
alg::AbstractSimpleNonlinearSolveAlgorithm,
17+
args...;
18+
kwargs...) where {T, V, P, iip}
19+
prob = convert(ImmutableNonlinearProblem, prob)
20+
sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...)
21+
dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p)
22+
return SciMLBase.build_solution(
23+
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
1324
end
1425

1526
for algType in (Bisection, Brent, Alefeld, Falsi, ITP, Ridder)

0 commit comments

Comments
 (0)