1
- function scalar_nlsolve_ad (prob, alg, args... ; kwargs... )
2
- f = prob. f
1
+ function SciMLBase. solve (prob:: NonlinearProblem {<: Union{Number, <:AbstractArray} ,
2
+ iip, <: Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} },
3
+ alg:: Union{Nothing, AbstractNonlinearAlgorithm} , args... ;
4
+ kwargs... ) where {T, V, P, iip}
5
+ sol, partials = __nlsolve_ad (prob, alg, args... ; kwargs... )
6
+ dual_soln = __nlsolve_dual_soln (sol. u, partials, prob. p)
7
+ return SciMLBase. build_solution (prob, alg, dual_soln, sol. resid; sol. retcode, sol. stats,
8
+ sol. original)
9
+ end
10
+
11
+ @concrete mutable struct NonlinearSolveForwardDiffCache
12
+ cache
13
+ prob
14
+ alg
15
+ p
16
+ values_p
17
+ partials_p
18
+ end
19
+
20
+ @inline function __has_duals (:: Union {<: Dual{T, V, P} ,
21
+ <: AbstractArray{<:Dual{T, V, P}} }) where {T, V, P}
22
+ return true
23
+ end
24
+ @inline __has_duals (:: Any ) = false
25
+
26
+ function SciMLBase. reinit! (cache:: NonlinearSolveForwardDiffCache ; p = cache. p,
27
+ u0 = get_u (cache. cache), kwargs... )
28
+ inner_cache = SciMLBase. reinit! (cache. cache; p = value (p), u0 = value (u0), kwargs... )
29
+ cache. cache = inner_cache
30
+ cache. p = p
31
+ cache. values_p = value (p)
32
+ cache. partials_p = ForwardDiff. partials (p)
33
+ return cache
34
+ end
35
+
36
+ function SciMLBase. init (prob:: NonlinearProblem {<: Union{Number, <:AbstractArray} ,
37
+ iip, <: Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} },
38
+ alg:: Union{Nothing, AbstractNonlinearAlgorithm} , args... ;
39
+ kwargs... ) where {T, V, P, iip}
3
40
p = value (prob. p)
4
- u0 = value (prob. u0)
5
- newprob = NonlinearProblem (f, u0, p; prob. kwargs... )
41
+ newprob = NonlinearProblem (prob. f, value (prob. u0), p; prob. kwargs... )
42
+ cache = init (newprob, alg, args... ; kwargs... )
43
+ return NonlinearSolveForwardDiffCache (cache, newprob, alg, prob. p, p,
44
+ ForwardDiff. partials (prob. p))
45
+ end
46
+
47
+ function SciMLBase. solve! (cache:: NonlinearSolveForwardDiffCache )
48
+ sol = solve! (cache. cache)
49
+ prob = cache. prob
50
+
51
+ uu = sol. u
52
+ f_p = __nlsolve_∂f_∂p (prob, prob. f, uu, cache. values_p)
53
+ f_x = __nlsolve_∂f_∂u (prob, prob. f, uu, cache. values_p)
54
+
55
+ z_arr = - f_x \ f_p
56
+
57
+ sumfun = ((z, p),) -> map (zᵢ -> zᵢ * ForwardDiff. partials (p), z)
58
+ if cache. p isa Number
59
+ partials = sumfun ((z_arr, cache. p))
60
+ else
61
+ partials = sum (sumfun, zip (eachcol (z_arr), cache. p))
62
+ end
63
+
64
+ dual_soln = __nlsolve_dual_soln (sol. u, partials, cache. p)
65
+ return SciMLBase. build_solution (prob, cache. alg, dual_soln, sol. resid; sol. retcode,
66
+ sol. stats, sol. original)
67
+ end
68
+
69
+ function __nlsolve_ad (prob:: NonlinearProblem{uType, iip} , alg, args... ;
70
+ kwargs... ) where {uType, iip}
71
+ p = value (prob. p)
72
+ newprob = NonlinearProblem (prob. f, value (prob. u0), p; prob. kwargs... )
6
73
7
74
sol = solve (newprob, alg, args... ; kwargs... )
8
75
9
76
uu = sol. u
10
- f_p = scalar_nlsolve_ ∂f_∂p (f, uu, p)
11
- f_x = scalar_nlsolve_ ∂f_∂u (f, uu, p)
77
+ f_p = __nlsolve_ ∂f_∂p (prob, prob . f, uu, p)
78
+ f_x = __nlsolve_ ∂f_∂u (prob, prob . f, uu, p)
12
79
13
- z_arr = - inv ( f_x) * f_p
80
+ z_arr = - f_x \ f_p
14
81
15
82
pp = prob. p
16
83
sumfun = ((z, p),) -> map (zᵢ -> zᵢ * ForwardDiff. partials (p), z)
@@ -25,39 +92,47 @@ function scalar_nlsolve_ad(prob, alg, args...; kwargs...)
25
92
return sol, partials
26
93
end
27
94
28
- function SciMLBase. solve (prob:: NonlinearProblem {<: Union{Number, SVector, <:AbstractArray} ,
29
- false , <: Dual{T, V, P} }, alg:: AbstractNonlinearSolveAlgorithm , args... ;
30
- kwargs... ) where {T, V, P}
31
- sol, partials = scalar_nlsolve_ad (prob, alg, args... ; kwargs... )
32
- dual_soln = scalar_nlsolve_dual_soln (sol. u, partials, prob. p)
33
- return SciMLBase. build_solution (prob, alg, dual_soln, sol. resid; sol. retcode)
34
- end
35
-
36
- function SciMLBase. solve (prob:: NonlinearProblem {<: Union{Number, SVector, <:AbstractArray} ,
37
- false , <: AbstractArray{<:Dual{T, V, P}} }, alg:: AbstractNonlinearSolveAlgorithm ,
38
- args... ; kwargs... ) where {T, V, P}
39
- sol, partials = scalar_nlsolve_ad (prob, alg, args... ; kwargs... )
40
- dual_soln = scalar_nlsolve_dual_soln (sol. u, partials, prob. p)
41
- return SciMLBase. build_solution (prob, alg, dual_soln, sol. resid; sol. retcode)
42
- end
43
-
44
- function scalar_nlsolve_∂f_∂p (f, u, p)
45
- ff = p isa Number ? ForwardDiff. derivative :
46
- (u isa Number ? ForwardDiff. gradient : ForwardDiff. jacobian)
47
- return ff (Base. Fix1 (f, u), p)
95
+ @inline function __nlsolve_∂f_∂p (prob, f:: F , u, p) where {F}
96
+ if isinplace (prob)
97
+ __f = p -> begin
98
+ du = similar (u, promote_type (eltype (u), eltype (p)))
99
+ f (du, u, p)
100
+ return du
101
+ end
102
+ else
103
+ __f = Base. Fix1 (f, u)
104
+ end
105
+ if p isa Number
106
+ return __reshape (ForwardDiff. derivative (__f, p), :, 1 )
107
+ elseif u isa Number
108
+ return __reshape (ForwardDiff. gradient (__f, p), 1 , :)
109
+ else
110
+ return ForwardDiff. jacobian (__f, p)
111
+ end
48
112
end
49
113
50
- function scalar_nlsolve_∂f_∂u (f, u, p)
51
- ff = u isa Number ? ForwardDiff. derivative : ForwardDiff. jacobian
52
- return ff (Base. Fix2 (f, p), u)
114
+ @inline function __nlsolve_∂f_∂u (prob, f:: F , u, p) where {F}
115
+ if isinplace (prob)
116
+ du = similar (u)
117
+ __f = (du, u) -> f (du, u, p)
118
+ ForwardDiff. jacobian (__f, du, u)
119
+ else
120
+ __f = Base. Fix2 (f, p)
121
+ if u isa Number
122
+ return ForwardDiff. derivative (__f, u)
123
+ else
124
+ return ForwardDiff. jacobian (__f, u)
125
+ end
126
+ end
53
127
end
54
128
55
- function scalar_nlsolve_dual_soln (u:: Number , partials,
129
+ @inline function __nlsolve_dual_soln (u:: Number , partials,
56
130
:: Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}} ) where {T, V, P}
57
131
return Dual {T, V, P} (u, partials)
58
132
end
59
133
60
- function scalar_nlsolve_dual_soln (u:: AbstractArray , partials,
134
+ @inline function __nlsolve_dual_soln (u:: AbstractArray , partials,
61
135
:: Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}} ) where {T, V, P}
62
- return map (((uᵢ, pᵢ),) -> Dual {T, V, P} (uᵢ, pᵢ), zip (u, partials))
136
+ _partials = _restructure (u, partials)
137
+ return map (((uᵢ, pᵢ),) -> Dual {T, V, P} (uᵢ, pᵢ), zip (u, _partials))
63
138
end
0 commit comments