Skip to content

Commit e4534ed

Browse files
author
Avik Pal
committed
Add inplace and cached ForwardDiff rules
1 parent f184a67 commit e4534ed

File tree

2 files changed

+133
-34
lines changed

2 files changed

+133
-34
lines changed

src/ad.jl

Lines changed: 85 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,78 @@ function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, <:AbstractArray}
44
kwargs...) where {T, V, P, iip}
55
sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...)
66
dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p)
7-
return SciMLBase.build_solution(prob, alg, dual_soln, sol.resid; sol.retcode)
7+
return SciMLBase.build_solution(prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats,
8+
sol.original)
89
end
910

10-
# Differentiate Out-of-Place Nonlinear Root Finding Problems
11-
function __nlsolve_ad(prob::NonlinearProblem{uType, false}, alg, args...;
12-
kwargs...) where {uType}
11+
@concrete mutable struct NonlinearSolveForwardDiffCache
12+
cache
13+
prob
14+
alg
15+
p
16+
values_p
17+
partials_p
18+
end
19+
20+
@inline function __has_duals(::Union{<:Dual{T, V, P},
21+
<:AbstractArray{<:Dual{T, V, P}}}) where {T, V, P}
22+
return true
23+
end
24+
@inline __has_duals(::Any) = false
25+
26+
function SciMLBase.reinit!(cache::NonlinearSolveForwardDiffCache; p = cache.p,
27+
u0 = get_u(cache.cache), kwargs...)
28+
inner_cache = SciMLBase.reinit!(cache.cache; p = value(p), u0 = value(u0), kwargs...)
29+
cache.cache = inner_cache
30+
cache.p = p
31+
cache.values_p = value(p)
32+
cache.partials_p = ForwardDiff.partials(p)
33+
return cache
34+
end
35+
36+
function SciMLBase.init(prob::NonlinearProblem{<:Union{Number, <:AbstractArray},
37+
iip, <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}},
38+
alg::Union{Nothing, AbstractNonlinearAlgorithm}, args...;
39+
kwargs...) where {T, V, P, iip}
40+
p = value(prob.p)
41+
newprob = NonlinearProblem(prob.f, value(prob.u0), p; prob.kwargs...)
42+
cache = init(newprob, alg, args...; kwargs...)
43+
return NonlinearSolveForwardDiffCache(cache, newprob, alg, prob.p, p,
44+
ForwardDiff.partials(prob.p))
45+
end
46+
47+
function SciMLBase.solve!(cache::NonlinearSolveForwardDiffCache)
48+
sol = solve!(cache.cache)
49+
prob = cache.prob
50+
51+
uu = sol.u
52+
f_p = __nlsolve_∂f_∂p(prob, prob.f, uu, cache.values_p)
53+
f_x = __nlsolve_∂f_∂u(prob, prob.f, uu, cache.values_p)
54+
55+
z_arr = -f_x \ f_p
56+
57+
sumfun = ((z, p),) -> map(zᵢ -> zᵢ * ForwardDiff.partials(p), z)
58+
if cache.p isa Number
59+
partials = sumfun((z_arr, cache.p))
60+
else
61+
partials = sum(sumfun, zip(eachcol(z_arr), cache.p))
62+
end
63+
64+
dual_soln = __nlsolve_dual_soln(sol.u, partials, cache.p)
65+
return SciMLBase.build_solution(prob, cache.alg, dual_soln, sol.resid; sol.retcode,
66+
sol.stats, sol.original)
67+
end
68+
69+
function __nlsolve_ad(prob::NonlinearProblem{uType, iip}, alg, args...;
70+
kwargs...) where {uType, iip}
1371
p = value(prob.p)
1472
newprob = NonlinearProblem(prob.f, value(prob.u0), p; prob.kwargs...)
1573

1674
sol = solve(newprob, alg, args...; kwargs...)
1775

1876
uu = sol.u
19-
f_p = __nlsolve_∂f_∂p(prob.f, uu, p)
20-
f_x = __nlsolve_∂f_∂u(prob.f, uu, p)
77+
f_p = __nlsolve_∂f_∂p(prob, prob.f, uu, p)
78+
f_x = __nlsolve_∂f_∂u(prob, prob.f, uu, p)
2179

2280
z_arr = -f_x \ f_p
2381

@@ -34,8 +92,16 @@ function __nlsolve_ad(prob::NonlinearProblem{uType, false}, alg, args...;
3492
return sol, partials
3593
end
3694

37-
@inline function __nlsolve_∂f_∂p(f::F, u, p) where {F}
38-
__f = Base.Fix1(f, u)
95+
@inline function __nlsolve_∂f_∂p(prob, f::F, u, p) where {F}
96+
if isinplace(prob)
97+
__f = p -> begin
98+
du = similar(u, promote_type(eltype(u), eltype(p)))
99+
f(du, u, p)
100+
return du
101+
end
102+
else
103+
__f = Base.Fix1(f, u)
104+
end
39105
if p isa Number
40106
return __reshape(ForwardDiff.derivative(__f, p), :, 1)
41107
elseif u isa Number
@@ -45,12 +111,18 @@ end
45111
end
46112
end
47113

48-
@inline function __nlsolve_∂f_∂u(f::F, u, p) where {F}
49-
__f = Base.Fix2(f, p)
50-
if u isa Number
51-
return ForwardDiff.derivative(__f, u)
114+
@inline function __nlsolve_∂f_∂u(prob, f::F, u, p) where {F}
115+
if isinplace(prob)
116+
du = similar(u)
117+
__f = (du, u) -> f(du, u, p)
118+
ForwardDiff.jacobian(__f, du, u)
52119
else
53-
return ForwardDiff.jacobian(__f, u)
120+
__f = Base.Fix2(f, p)
121+
if u isa Number
122+
return ForwardDiff.derivative(__f, u)
123+
else
124+
return ForwardDiff.jacobian(__f, u)
125+
end
54126
end
55127
end
56128

test/forward_ad.jl

Lines changed: 48 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,33 @@ jacobian_f(::Number, p::Number) = 1 / (2 * √p)
88
jacobian_f(u, p::Number) = one.(u) .* (1 / (2 * p))
99
jacobian_f(u, p::AbstractArray) = diagm(vec(@. 1 / (2 * p)))
1010

11-
function solve_with(::Val{iip}, u, alg) where {iip}
12-
f = if iip
11+
function solve_with(::Val{mode}, u, alg) where {mode}
12+
f = if mode === :iip
1313
solve_iip(p) = solve(NonlinearProblem(test_f!, u, p), alg).u
14-
else
14+
elseif mode === :iip_cache
15+
function solve_iip_init(p)
16+
cache = SciMLBase.init(NonlinearProblem(test_f!, u, p), alg)
17+
return SciMLBase.solve!(cache).u
18+
end
19+
elseif mode === :oop
1520
solve_oop(p) = solve(NonlinearProblem(test_f, u, p), alg).u
21+
elseif mode === :oop_cache
22+
function solve_oop_init(p)
23+
cache = SciMLBase.init(NonlinearProblem(test_f, u, p), alg)
24+
return SciMLBase.solve!(cache).u
25+
end
1626
end
1727
return f
1828
end
1929

20-
__can_inplace(::Number) = false
21-
__can_inplace(::AbstractArray) = true
22-
__can_inplace(::StaticArray) = false
30+
__compatible(::Any, ::Val{:oop}) = true
31+
__compatible(::Any, ::Val{:oop_cache}) = true
32+
__compatible(::Number, ::Val{:iip}) = false
33+
__compatible(::AbstractArray, ::Val{:iip}) = true
34+
__compatible(::StaticArray, ::Val{:iip}) = false
35+
__compatible(::Number, ::Val{:iip_cache}) = false
36+
__compatible(::AbstractArray, ::Val{:iip_cache}) = true
37+
__compatible(::StaticArray, ::Val{:iip_cache}) = false
2338

2439
__compatible(::Any, ::Number) = true
2540
__compatible(::Number, ::AbstractArray) = false
@@ -32,37 +47,49 @@ __compatible(u::StaticArray, ::SciMLBase.AbstractNonlinearAlgorithm) = true
3247
__compatible(u::StaticArray, ::Union{CMINPACK, NLsolveJL}) = false
3348
__compatible(u, ::Nothing) = true
3449

50+
__compatible(::Any, ::Any) = true
51+
__compatible(::CMINPACK, ::Val{:iip_cache}) = false
52+
__compatible(::CMINPACK, ::Val{:oop_cache}) = false
53+
__compatible(::NLsolveJL, ::Val{:iip_cache}) = false
54+
__compatible(::NLsolveJL, ::Val{:oop_cache}) = false
55+
3556
@testset "ForwardDiff.jl Integration: $(alg)" for alg in (NewtonRaphson(), TrustRegion(),
3657
LevenbergMarquardt(), PseudoTransient(; alpha_initial = 10.0), Broyden(), Klement(),
3758
DFSane(), nothing, NLsolveJL(), CMINPACK())
3859
us = (2.0, @SVector[1.0, 1.0], [1.0, 1.0], ones(2, 2), @SArray ones(2, 2))
3960

4061
@testset "Scalar AD" begin
41-
for p in 1.0:0.1:100.0
42-
for u0 in us
43-
__compatible(u0, alg) || continue
44-
sol = solve(NonlinearProblem(test_f, u0, p), alg)
45-
if SciMLBase.successful_retcode(sol)
46-
gs = abs.(ForwardDiff.derivative(solve_with(Val{false}(), u0, alg), p))
47-
gs_true = abs.(jacobian_f(u0, p))
48-
if !(isapprox(gs, gs_true, atol = 1e-5))
49-
@show sol.retcode, sol.u
50-
@error "ForwardDiff Failed for u0=$(u0) and p=$(p) with $(alg)" forwardiff_gradient=gs true_gradient=gs_true
51-
else
52-
@test abs.(gs)≈abs.(gs_true) atol=1e-5
53-
end
62+
for p in 1.0:0.1:100.0, u0 in us, mode in (:iip, :oop, :iip_cache, :oop_cache)
63+
__compatible(u0, alg) || continue
64+
__compatible(u0, Val(mode)) || continue
65+
__compatible(alg, Val(mode)) || continue
66+
67+
sol = solve(NonlinearProblem(test_f, u0, p), alg)
68+
if SciMLBase.successful_retcode(sol)
69+
gs = abs.(ForwardDiff.derivative(solve_with(Val{mode}(), u0, alg), p))
70+
gs_true = abs.(jacobian_f(u0, p))
71+
if !(isapprox(gs, gs_true, atol = 1e-5))
72+
@show sol.retcode, sol.u
73+
@error "ForwardDiff Failed for u0=$(u0) and p=$(p) with $(alg)" forwardiff_gradient=gs true_gradient=gs_true
74+
else
75+
@test abs.(gs)≈abs.(gs_true) atol=1e-5
5476
end
5577
end
5678
end
5779
end
5880

5981
@testset "Jacobian" begin
60-
for u0 in us, p in ([2.0, 1.0], [2.0 1.0; 3.0 4.0])
82+
for u0 in us, p in ([2.0, 1.0], [2.0 1.0; 3.0 4.0]),
83+
mode in (:iip, :oop, :iip_cache, :oop_cache)
84+
6185
__compatible(u0, p) || continue
6286
__compatible(u0, alg) || continue
87+
__compatible(u0, Val(mode)) || continue
88+
__compatible(alg, Val(mode)) || continue
89+
6390
sol = solve(NonlinearProblem(test_f, u0, p), alg)
6491
if SciMLBase.successful_retcode(sol)
65-
gs = abs.(ForwardDiff.jacobian(solve_with(Val{false}(), u0, alg), p))
92+
gs = abs.(ForwardDiff.jacobian(solve_with(Val{mode}(), u0, alg), p))
6693
gs_true = abs.(jacobian_f(u0, p))
6794
if !(isapprox(gs, gs_true, atol = 1e-5))
6895
@show sol.retcode, sol.u

0 commit comments

Comments
 (0)